In [None]:
!kill -9 1237213      

In [41]:
!nvidia-smi

Wed Jun  7 06:23:10 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA RTX A6000                On | 00000000:23:00.0 Off |                  Off |
| 30%   49C    P2              126W / 300W|  18967MiB / 49140MiB |     21%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                         

In [2]:
import numpy as np
import argparse
import os
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
from torch import optim
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision.datasets import CIFAR100, CIFAR10

In [3]:
import clip

In [14]:
def count_parameter(model):
    return sum(p.numel() for p in model.parameters())

count_parameter(model.visual)

56259936

In [12]:
model_name = 'RN101'
batch_size = 100
epochs = 10

available_model = clip.available_models()

print(available_model)

assert model_name in available_model

['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']


In [13]:
model, preprocess = clip.load(model_name)

input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)
print(preprocess)

100%|███████████████████████████████████████| 278M/278M [01:29<00:00, 3.26MiB/s]


Model parameters: 119,688,033
Input resolution: 224
Context length: 77
Vocab size: 49408
Compose(
    Resize(size=224, interpolation=bicubic, max_size=None, antialias=None)
    CenterCrop(size=(224, 224))
    <function _convert_image_to_rgb at 0x7f4a3a0795e0>
    ToTensor()
    Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)


In [37]:
def accuracy(output, target):
    pred = output.topk(1, 1, True, True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return float(correct[0].reshape(-1).float().sum(0, keepdim=True).cpu().numpy())

In [38]:
class MyDataset(Dataset):
    def __init__(self, dataset, template):
        self.dataset = dataset
        self.classes = dataset.classes
        
        prompt = []
        
        for data in tqdm(dataset):
            prompt.append(template.format(self.classes[data[1]])) 
        
        self.prompt = clip.tokenize(prompt)
        
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        image = self.dataset[idx][0]
        label = self.dataset[idx][1]
        text = self.prompt[idx]
        
        return image, label, text

In [20]:
prompt = 'this is a photo of a {}'

cifar10_train = CIFAR10('./cifar10_data', train=True, transform=preprocess, download=False)
cifar10_train_dataloader = DataLoader(MyDataset(cifar10_train, prompt), batch_size=batch_size, shuffle=False)

cifar10_valid = CIFAR10('./cifar10_data', train=False, transform=preprocess, download=False)
cifar10_valid_dataloader = DataLoader(MyDataset(cifar10_valid, prompt), batch_size=batch_size, shuffle=False)

100%|███████████████████████████████████████████████████████████████████████████| 50000/50000 [00:41<00:00, 1201.94it/s]
100%|███████████████████████████████████████████████████████████████████████████| 10000/10000 [00:08<00:00, 1199.75it/s]


In [39]:
device = "cuda"

model = model.cuda()

loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=5e-5,betas=(0.9,0.98),eps=1e-6,weight_decay=0.2)

In [40]:
for epoch in tqdm(range(epochs)):
    model.train()
    
    _loss= 0
    _correct = 0
    _total= 0
    
    for i, data in enumerate(cifar10_train_dataloader):
        inputs, labels, texts = data
        inputs, labels, texts = inputs.cuda(), labels.cuda(), texts.cuda()
        
        optimizer.zero_grad()
        
#         logits_per_image, logits_per_text = model(inputs, texts)
        
#         ground_truth = torch.arange(len(inputs),dtype=torch.long,device=device)

#         total_loss = (loss_img(logits_per_image,labels) + loss_txt(logits_per_text,labels))/2
#         total_loss.backward()
        logits = model.visual(inputs)
        
        loss = loss_img(logits, labels)
        
        loss.backward()
        optimizer.step()
        
        _loss += loss.item()
        _total += labels.size(0)
        _correct += accuracy(logits, labels)

        
    print(f"----train----\nloss: {_loss / _total}    accuracy: {_correct / _total * 100}\n")
    
    model.eval()
    
    _loss= 0
    _correct = 0
    _total= 0
    
    best_acc = 0
    best_epoch = 0
    
    for i, data in enumerate(cifar10_valid_dataloader):
        with torch.no_grad():
            inputs, labels, texts = data
            inputs, labels, texts = inputs.cuda(), labels.cuda(), texts.cuda()

            optimizer.zero_grad()

    #         logits_per_image, logits_per_text = model(inputs, texts)

    #         ground_truth = torch.arange(len(inputs),dtype=torch.long,device=device)

    #         total_loss = (loss_img(logits_per_image,labels) + loss_txt(logits_per_text,labels))/2
    #         total_loss.backward()
            logits = model.visual(inputs)

            loss = loss_img(logits, labels)

            _loss += loss.item()
            _total += labels.size(0)
            _correct += accuracy(logits, labels)
            
    print(f"----test----\nloss: {_loss / _total}    accuracy: {_correct / _total * 100}\n")
    
    if best_acc < (_correct / _total * 100):
        best_acc = _correct / _total * 100
        best_epoch = epoch
        
print(f"Best Accuracy: {best_acc} Best EPOCH: {best_epoch}")

  0%|                                                                                            | 0/10 [00:00<?, ?it/s]

----train----
loss: nan    accuracy: 9.994



 10%|████████▎                                                                          | 1/10 [03:04<27:36, 184.00s/it]

----test----
loss: nan    accuracy: 10.0

----train----
loss: nan    accuracy: 10.0



 20%|████████████████▌                                                                  | 2/10 [06:08<24:33, 184.22s/it]

----test----
loss: nan    accuracy: 10.0



 20%|████████████████▌                                                                  | 2/10 [09:50<39:20, 295.12s/it]


KeyboardInterrupt: 