In [1]:
from conch.open_clip_custom import create_model_from_pretrained, tokenize, get_tokenizer
import torch
from torch import nn
import os
from PIL import Image
from pathlib import Path
from tqdm import tqdm
import pandas as pd
import glob

import matplotlib.pyplot as plt
import numpy as np

import tqdm

import skimage

from torch.utils.data import DataLoader, Dataset

# show all jupyter output
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"



In [2]:
root = Path('../').resolve()
os.chdir(root)

In [3]:
class CONCHModelForFinetuning(nn.Module):
    def __init__(self, num_classes=8, config={'hidden_size': 512}):
        super().__init__()
        self.config = config
        self.model = self.make_conch()
        self.fc = nn.Linear(self.config['hidden_size'], num_classes)

    def make_conch(self):
        # Load the model from "create_model_from_pretrained"
        model_cfg = 'conch_ViT-B-16'
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # checkpoint_path = 'checkpoints/CONCH/pytorch_model.bin'
        checkpoint_path = 'C:\\Users\\Vivian\\Documents\\CONCH\\checkpoints\\conch\\pytorch_model.bin' 
        model, preprocess = create_model_from_pretrained(model_cfg, checkpoint_path, device=device)
        
        return model
        
    def forward(self, x):
        out, h = self.model.visual(x)
        return self.fc(out)

In [4]:
model = CONCHModelForFinetuning().to('cuda')

  checkpoint = torch.load(checkpoint_path, map_location=map_location)


In [5]:
'SOB_B_A-14-22549AB-40-001'

'SOB_B_A-14-22549AB-40-001'

In [6]:

os.listdir('/Users/Vivian/Documents/CONCH/')

['.git',
 '.gitignore',
 'basics_usage.ipynb',
 'binary-predictions.csv',
 'checkpoints',
 'CISC882_scripts',
 'conch',
 'conch.egg-info',
 'conch.jpg',
 'CONCH_finetuning.ipynb',
 'conch_hr.png',
 'CONCH_kfold.ipynb',
 'conch_kfold2.png',
 'conch_model1.py',
 'CONCH_scripts',
 'CONCH_shuffle-LR.ipynb',
 'CONCH_shuffle.ipynb',
 'CONCH_zeroshot-images.ipynb',
 'data_embeddings',
 'data_features',
 'data_tensors',
 'docs',
 'finetune_conch.ipynb',
 'Folds',
 'LICENSE',
 'MANIFEST.in',
 'notebooks',
 'preprocess_WSIs.ipynb',
 'prompts',
 'pyproject.toml',
 'README.md',
 'zeroshot_classification_example_starter copy.ipynb',
 'zeroshot_classification_example_starter.ipynb']

In [7]:
def make_metadata(num_folds=5):
    metadata = pd.DataFrame()
    for fold in range(num_folds):
        print(f'Fold {fold+1}')
        for mode in ['train', 'test']:
            pathname = f'/Users/Vivian/Documents/CONCH/Folds/Fold {fold+1}/{mode}/'
            images = os.listdir(pathname)
            for image in images:
                label = image.split('-')[0].replace('SOB_', '')
                class_name, subclass_name = label.split('_')
                #metadata = metadata.append({'image': pathname+image, 'fold': fold, 'mode': mode, 'class': class_name, 'subclass': subclass_name}, ignore_index=True)
                metadata = pd.concat([metadata, pd.DataFrame({'image': pathname+image, 'fold': fold, 'mode': mode, 'class': class_name, 'subclass': subclass_name}, index=[0])], ignore_index=True)
    return metadata

In [8]:
def make_metadata(fold):
    metadata = pd.DataFrame()
    for mode in ['train', 'test']:
        pathname = f'/Users/Vivian/Documents/CONCH/Folds/Fold {fold}/{mode}/'
        images = os.listdir(pathname)
        for image in images:
            if not image.startswith('SOB'):
                continue
            label = image.split('-')[0].replace('SOB_', '')
            class_name, subclass_name = label.split('_')
            #metadata = metadata.append({'image': pathname+image, 'fold': fold, 'mode': mode, 'class': class_name, 'subclass': subclass_name}, ignore_index=True)
            metadata = pd.concat([metadata, pd.DataFrame({'image': pathname+image, 'fold': fold, 'mode': mode, 'class': class_name, 'subclass': subclass_name}, index=[0])], ignore_index=True)
        metadata.to_csv(f'/Users/Vivian/Documents/CONCH/Folds/Fold {fold}/{mode}/metadata.csv', index=False)
    return metadata

In [9]:
#for fold in range(5):
#    make_metadata(fold+1)

In [10]:
"SOB_B_A-14-22549AB-100-001.png".split('-')[0].replace('SOB_', '')

'B_A'

In [11]:
# Custom Dataset class
class HistopathologyDataset(Dataset):
    def __init__(self, csv_file, transform=None):
        self.data = pd.read_csv(csv_file)
        self.transform = transform
        self.label_map = {
            'B_A': 0,
            'B_F': 1,
            'B_PT': 2,
            'B_TA': 3,
            'M_DC': 4,
            'M_LC': 5,
            'M_MC': 6,
            'M_PC': 7
        }  # Example mapping of subclasses to numerical labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path = self.data.iloc[idx]['image']
        class_name = self.data.iloc[idx]['class']
        subclass_name = self.data.iloc[idx]['subclass']
        label = self.label_map[class_name + '_' + subclass_name]
        image = plt.imread(img_path)
        image = skimage.transform.resize(image, (224, 224))
        image = image.transpose((2, 0, 1))
        if self.transform:
            image = self.transform(image)
        return image, label



In [12]:
train_data = HistopathologyDataset('/Users/Vivian/Documents/CONCH/Folds/Fold 1/train/metadata.csv')
test_data = HistopathologyDataset('/Users/Vivian/Documents/CONCH/Folds/Fold 1/test/metadata.csv')

In [13]:
# make a dataloder for me please 
train_loader = DataLoader(train_data, batch_size=8, shuffle=True)
test_loader = DataLoader(test_data, batch_size=8, shuffle=True)

In [15]:
# now please write me a training loop to finetuen the model
# and save the best model to disk

# define the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

# define the loss function
criterion = torch.nn.CrossEntropyLoss()

# define the number of epochs
num_epochs = 10

# define the path to save the model
model_save_path = '/Users/Vivian/Documents/CONCH/_finetune_weights/fine_tuned_model.pth'

# define the best accuracy
best_accuracy = 0

device = 'cuda'

# start the training loop
for epoch in range(num_epochs):
    print(f'Starting epoch {epoch+1}/{num_epochs}')
    model.train()
    for batch in tqdm.tqdm(train_loader, desc='training'):
        images, labels = batch
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    print(f'Epoch {epoch+1}/{num_epochs} done')
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for batch in tqdm.tqdm(test_loader, desc='testing'):
            images, labels = batch
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        accuracy = correct / total
        print(f'Accuracy: {accuracy}')
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), model_save_path)
            print(f'Model saved with accuracy: {accuracy}')

Starting epoch 1/10


CONCHModelForFinetuning(
  (model): CoCa(
    (text): TextTransformer(
      (token_embedding): Embedding(32007, 768)
      (transformer): Transformer(
        (resblocks): ModuleList(
          (0-11): 12 x ResidualAttentionBlock(
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (ls_1): Identity()
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
              (gelu): GELU(approximate='none')
              (c_proj): Linear(in_features=3072, out_features=768, bias=True)
            )
            (ls_2): Identity()
          )
        )
      )
      (ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (visual): VisualModel(
      (trunk): Visi

training: 100%|██████████| 626/626 [01:35<00:00,  6.54it/s]

Epoch 1/10 done





CONCHModelForFinetuning(
  (model): CoCa(
    (text): TextTransformer(
      (token_embedding): Embedding(32007, 768)
      (transformer): Transformer(
        (resblocks): ModuleList(
          (0-11): 12 x ResidualAttentionBlock(
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (ls_1): Identity()
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
              (gelu): GELU(approximate='none')
              (c_proj): Linear(in_features=3072, out_features=768, bias=True)
            )
            (ls_2): Identity()
          )
        )
      )
      (ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (visual): VisualModel(
      (trunk): Visi

testing: 100%|██████████| 989/989 [03:30<00:00,  4.70it/s]


Accuracy: 0.8389176887090656
Model saved with accuracy: 0.8389176887090656
Starting epoch 2/10


CONCHModelForFinetuning(
  (model): CoCa(
    (text): TextTransformer(
      (token_embedding): Embedding(32007, 768)
      (transformer): Transformer(
        (resblocks): ModuleList(
          (0-11): 12 x ResidualAttentionBlock(
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (ls_1): Identity()
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
              (gelu): GELU(approximate='none')
              (c_proj): Linear(in_features=3072, out_features=768, bias=True)
            )
            (ls_2): Identity()
          )
        )
      )
      (ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (visual): VisualModel(
      (trunk): Visi

training: 100%|██████████| 626/626 [01:35<00:00,  6.54it/s]

Epoch 2/10 done





CONCHModelForFinetuning(
  (model): CoCa(
    (text): TextTransformer(
      (token_embedding): Embedding(32007, 768)
      (transformer): Transformer(
        (resblocks): ModuleList(
          (0-11): 12 x ResidualAttentionBlock(
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (ls_1): Identity()
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
              (gelu): GELU(approximate='none')
              (c_proj): Linear(in_features=3072, out_features=768, bias=True)
            )
            (ls_2): Identity()
          )
        )
      )
      (ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (visual): VisualModel(
      (trunk): Visi

testing: 100%|██████████| 989/989 [03:27<00:00,  4.77it/s]

Accuracy: 0.8315842710835757
Starting epoch 3/10





CONCHModelForFinetuning(
  (model): CoCa(
    (text): TextTransformer(
      (token_embedding): Embedding(32007, 768)
      (transformer): Transformer(
        (resblocks): ModuleList(
          (0-11): 12 x ResidualAttentionBlock(
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (ls_1): Identity()
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
              (gelu): GELU(approximate='none')
              (c_proj): Linear(in_features=3072, out_features=768, bias=True)
            )
            (ls_2): Identity()
          )
        )
      )
      (ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (visual): VisualModel(
      (trunk): Visi

training: 100%|██████████| 626/626 [01:34<00:00,  6.64it/s]

Epoch 3/10 done





CONCHModelForFinetuning(
  (model): CoCa(
    (text): TextTransformer(
      (token_embedding): Embedding(32007, 768)
      (transformer): Transformer(
        (resblocks): ModuleList(
          (0-11): 12 x ResidualAttentionBlock(
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (ls_1): Identity()
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
              (gelu): GELU(approximate='none')
              (c_proj): Linear(in_features=3072, out_features=768, bias=True)
            )
            (ls_2): Identity()
          )
        )
      )
      (ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (visual): VisualModel(
      (trunk): Visi

testing: 100%|██████████| 989/989 [03:27<00:00,  4.76it/s]

Accuracy: 0.8353774181312429
Starting epoch 4/10





CONCHModelForFinetuning(
  (model): CoCa(
    (text): TextTransformer(
      (token_embedding): Embedding(32007, 768)
      (transformer): Transformer(
        (resblocks): ModuleList(
          (0-11): 12 x ResidualAttentionBlock(
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (ls_1): Identity()
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
              (gelu): GELU(approximate='none')
              (c_proj): Linear(in_features=3072, out_features=768, bias=True)
            )
            (ls_2): Identity()
          )
        )
      )
      (ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (visual): VisualModel(
      (trunk): Visi

training: 100%|██████████| 626/626 [01:35<00:00,  6.59it/s]

Epoch 4/10 done





CONCHModelForFinetuning(
  (model): CoCa(
    (text): TextTransformer(
      (token_embedding): Embedding(32007, 768)
      (transformer): Transformer(
        (resblocks): ModuleList(
          (0-11): 12 x ResidualAttentionBlock(
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (ls_1): Identity()
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
              (gelu): GELU(approximate='none')
              (c_proj): Linear(in_features=3072, out_features=768, bias=True)
            )
            (ls_2): Identity()
          )
        )
      )
      (ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (visual): VisualModel(
      (trunk): Visi

testing: 100%|██████████| 989/989 [03:27<00:00,  4.76it/s]

Accuracy: 0.8332279681375648
Starting epoch 5/10





CONCHModelForFinetuning(
  (model): CoCa(
    (text): TextTransformer(
      (token_embedding): Embedding(32007, 768)
      (transformer): Transformer(
        (resblocks): ModuleList(
          (0-11): 12 x ResidualAttentionBlock(
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (ls_1): Identity()
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
              (gelu): GELU(approximate='none')
              (c_proj): Linear(in_features=3072, out_features=768, bias=True)
            )
            (ls_2): Identity()
          )
        )
      )
      (ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (visual): VisualModel(
      (trunk): Visi

training: 100%|██████████| 626/626 [01:34<00:00,  6.61it/s]

Epoch 5/10 done





CONCHModelForFinetuning(
  (model): CoCa(
    (text): TextTransformer(
      (token_embedding): Embedding(32007, 768)
      (transformer): Transformer(
        (resblocks): ModuleList(
          (0-11): 12 x ResidualAttentionBlock(
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (ls_1): Identity()
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
              (gelu): GELU(approximate='none')
              (c_proj): Linear(in_features=3072, out_features=768, bias=True)
            )
            (ls_2): Identity()
          )
        )
      )
      (ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (visual): VisualModel(
      (trunk): Visi

testing: 100%|██████████| 989/989 [03:27<00:00,  4.76it/s]

Accuracy: 0.8337337210772537
Starting epoch 6/10





CONCHModelForFinetuning(
  (model): CoCa(
    (text): TextTransformer(
      (token_embedding): Embedding(32007, 768)
      (transformer): Transformer(
        (resblocks): ModuleList(
          (0-11): 12 x ResidualAttentionBlock(
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (ls_1): Identity()
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
              (gelu): GELU(approximate='none')
              (c_proj): Linear(in_features=3072, out_features=768, bias=True)
            )
            (ls_2): Identity()
          )
        )
      )
      (ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (visual): VisualModel(
      (trunk): Visi

training: 100%|██████████| 626/626 [01:34<00:00,  6.63it/s]

Epoch 6/10 done





CONCHModelForFinetuning(
  (model): CoCa(
    (text): TextTransformer(
      (token_embedding): Embedding(32007, 768)
      (transformer): Transformer(
        (resblocks): ModuleList(
          (0-11): 12 x ResidualAttentionBlock(
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (ls_1): Identity()
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
              (gelu): GELU(approximate='none')
              (c_proj): Linear(in_features=3072, out_features=768, bias=True)
            )
            (ls_2): Identity()
          )
        )
      )
      (ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (visual): VisualModel(
      (trunk): Visi

testing: 100%|██████████| 989/989 [03:28<00:00,  4.74it/s]

Accuracy: 0.8336072828423315
Starting epoch 7/10





CONCHModelForFinetuning(
  (model): CoCa(
    (text): TextTransformer(
      (token_embedding): Embedding(32007, 768)
      (transformer): Transformer(
        (resblocks): ModuleList(
          (0-11): 12 x ResidualAttentionBlock(
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (ls_1): Identity()
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
              (gelu): GELU(approximate='none')
              (c_proj): Linear(in_features=3072, out_features=768, bias=True)
            )
            (ls_2): Identity()
          )
        )
      )
      (ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (visual): VisualModel(
      (trunk): Visi

training: 100%|██████████| 626/626 [01:34<00:00,  6.62it/s]

Epoch 7/10 done





CONCHModelForFinetuning(
  (model): CoCa(
    (text): TextTransformer(
      (token_embedding): Embedding(32007, 768)
      (transformer): Transformer(
        (resblocks): ModuleList(
          (0-11): 12 x ResidualAttentionBlock(
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (ls_1): Identity()
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
              (gelu): GELU(approximate='none')
              (c_proj): Linear(in_features=3072, out_features=768, bias=True)
            )
            (ls_2): Identity()
          )
        )
      )
      (ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (visual): VisualModel(
      (trunk): Visi

testing: 100%|██████████| 989/989 [03:29<00:00,  4.71it/s]

Accuracy: 0.8331015299026425
Starting epoch 8/10





CONCHModelForFinetuning(
  (model): CoCa(
    (text): TextTransformer(
      (token_embedding): Embedding(32007, 768)
      (transformer): Transformer(
        (resblocks): ModuleList(
          (0-11): 12 x ResidualAttentionBlock(
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (ls_1): Identity()
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
              (gelu): GELU(approximate='none')
              (c_proj): Linear(in_features=3072, out_features=768, bias=True)
            )
            (ls_2): Identity()
          )
        )
      )
      (ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (visual): VisualModel(
      (trunk): Visi

training: 100%|██████████| 626/626 [01:35<00:00,  6.55it/s]

Epoch 8/10 done





CONCHModelForFinetuning(
  (model): CoCa(
    (text): TextTransformer(
      (token_embedding): Embedding(32007, 768)
      (transformer): Transformer(
        (resblocks): ModuleList(
          (0-11): 12 x ResidualAttentionBlock(
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (ls_1): Identity()
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
              (gelu): GELU(approximate='none')
              (c_proj): Linear(in_features=3072, out_features=768, bias=True)
            )
            (ls_2): Identity()
          )
        )
      )
      (ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (visual): VisualModel(
      (trunk): Visi

testing: 100%|██████████| 989/989 [03:29<00:00,  4.71it/s]

Accuracy: 0.8332279681375648
Starting epoch 9/10





CONCHModelForFinetuning(
  (model): CoCa(
    (text): TextTransformer(
      (token_embedding): Embedding(32007, 768)
      (transformer): Transformer(
        (resblocks): ModuleList(
          (0-11): 12 x ResidualAttentionBlock(
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (ls_1): Identity()
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
              (gelu): GELU(approximate='none')
              (c_proj): Linear(in_features=3072, out_features=768, bias=True)
            )
            (ls_2): Identity()
          )
        )
      )
      (ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (visual): VisualModel(
      (trunk): Visi

training: 100%|██████████| 626/626 [01:35<00:00,  6.59it/s]

Epoch 9/10 done





CONCHModelForFinetuning(
  (model): CoCa(
    (text): TextTransformer(
      (token_embedding): Embedding(32007, 768)
      (transformer): Transformer(
        (resblocks): ModuleList(
          (0-11): 12 x ResidualAttentionBlock(
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (ls_1): Identity()
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
              (gelu): GELU(approximate='none')
              (c_proj): Linear(in_features=3072, out_features=768, bias=True)
            )
            (ls_2): Identity()
          )
        )
      )
      (ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (visual): VisualModel(
      (trunk): Visi

testing: 100%|██████████| 989/989 [03:29<00:00,  4.71it/s]

Accuracy: 0.8329750916677203
Starting epoch 10/10





CONCHModelForFinetuning(
  (model): CoCa(
    (text): TextTransformer(
      (token_embedding): Embedding(32007, 768)
      (transformer): Transformer(
        (resblocks): ModuleList(
          (0-11): 12 x ResidualAttentionBlock(
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (ls_1): Identity()
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
              (gelu): GELU(approximate='none')
              (c_proj): Linear(in_features=3072, out_features=768, bias=True)
            )
            (ls_2): Identity()
          )
        )
      )
      (ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (visual): VisualModel(
      (trunk): Visi

training: 100%|██████████| 626/626 [01:35<00:00,  6.55it/s]

Epoch 10/10 done





CONCHModelForFinetuning(
  (model): CoCa(
    (text): TextTransformer(
      (token_embedding): Embedding(32007, 768)
      (transformer): Transformer(
        (resblocks): ModuleList(
          (0-11): 12 x ResidualAttentionBlock(
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (ls_1): Identity()
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
              (gelu): GELU(approximate='none')
              (c_proj): Linear(in_features=3072, out_features=768, bias=True)
            )
            (ls_2): Identity()
          )
        )
      )
      (ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (visual): VisualModel(
      (trunk): Visi

testing: 100%|██████████| 989/989 [03:29<00:00,  4.72it/s]

Accuracy: 0.8322164622581869



