In [1]:
import torch
import torchvision
from torchvision import transforms
from PIL import Image
import requests
import time
import numpy as np
import io
from io import BytesIO
import matplotlib.pyplot as plt
%matplotlib inline
import torchvision.models as models
from torch.autograd import Variable
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import os
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
import random
from tqdm import tqdm
import json
from torch.optim.lr_scheduler import CosineAnnealingLR
import threading
import torchvision.models as models
import torch.nn as nn
from pytorch_pretrained_bert import OpenAIGPTTokenizer, OpenAIGPTModel
from nltk.corpus import wordnet
from caption_transforms import SimCLRData_Caption_Transform
from image_transforms import SimCLRData_image_Transform
from dataset import FlickrDataset,Flickr30kDataset
from models import ResNetSimCLR,OpenAI_SIMCLR
from utils import get_gpu_stats,layerwise_trainable_parameters,count_trainable_parameters,get_gpu_memory
from metrics import inter_ContrastiveLoss, intra_ContrastiveLoss
from metrics import LARS,Optimizer_simclr
from logger import Logger
from train_fns import train, test


In [2]:
# Shell escape: show the current GPU status table before picking a device.
!nvidia-smi

Thu Apr 13 01:47:48 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-PCIE...  On   | 00000000:3B:00.0 Off |                  Off |
| N/A   35C    P0    36W / 250W |  12153MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-PCIE...  On   | 00000000:D8:00.0 Off |                  Off |
| N/A   35C    P0    37W / 250W |  11847MiB / 16160MiB |      0%      Defaul

In [2]:
# Select the GPU through environment variables: order devices by PCI bus id and
# expose only device "1" to this process. Both values are plain strings.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "1",
})
# Release any cached allocator blocks held by this kernel.
torch.cuda.empty_cache()

In [3]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# torch.cuda.current_device() raises when CUDA is unavailable, so only query it
# on the CUDA path; the CPU fallback chosen above must not crash the cell.
if device.type == "cuda":
    print(torch.cuda.current_device())
else:
    print("CUDA not available; running on CPU")


0


In [None]:
# Leftover inspection cell: a bare module reference only makes the notebook display its repr.
torch.cuda

In [4]:
# Dataset locations.
# NOTE(review): these are absolute paths from the original environment; point
# them at your own copy of the data before running.
IMAGE_DIR = '/work/08629/pradhakr/maverick2/cv_project/flickr30k-images'
CAPTION_FILE = '/work/08629/pradhakr/maverick2/cv_project/flickr30k_captions/results_20130124.token'

# Hold-out sizes and the seed that fixes which samples land in each split.
VAL_SIZE = 1000
TEST_SIZE = 1000
SPLIT_SEED = 42

dataset = Flickr30kDataset(IMAGE_DIR,
                           CAPTION_FILE,
                           caption_index_1=0,
                           caption_index_2=1,
                           image_transform=SimCLRData_image_Transform())

# Derive the training size from the dataset instead of hard-coding it, so the
# split stays valid if the dataset length changes, and seed the split so the
# same samples are held out after every kernel restart.
train_size = len(dataset) - VAL_SIZE - TEST_SIZE
train_set, val_set, test_set = torch.utils.data.random_split(
    dataset,
    [train_size, VAL_SIZE, TEST_SIZE],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

batch_size = 16
train_loader = DataLoader(train_set,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=4,
                          pin_memory=True)
val_loader = DataLoader(val_set,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=4,
                        pin_memory=True)
# Number of batches per pass over each loader.
print(len(train_loader))
print(len(val_loader))

1862
63


In [None]:
# Fetch only the first validation batch; the smoke-test cells below index into `s`.
# next(iter(...)) raises right here if the loader is empty, instead of leaving
# `s` undefined for a later cell to trip over.
s = next(iter(val_loader))

In [None]:
# Build both encoders for the smoke-test cells below. The two constructors take
# the same projection-head settings, so those are shared through one dict.
shared_head_kwargs = dict(
    intra_projection_dim=128,
    inter_projection_dim=1024,
    evaluate=False,
)

resnet_model = ResNetSimCLR(
    model='resnet50',
    layers_to_train=['layer3', 'layer4'],
    **shared_head_kwargs
)
resnet_model = resnet_model.to(device)

gpt_model = OpenAI_SIMCLR(
    model='openai-gpt',
    layers_to_train=['h.10', 'h.11'],
    **shared_head_kwargs
)
gpt_model = gpt_model.to(device)

In [None]:
# Smoke test: push the sampled batch `s` through both encoders once.
# Fields 1 and 2 of the batch go to the image model, fields 3 and 4 to the text model.
image_view_a, image_view_b, caption_a, caption_b = s[1], s[2], s[3], s[4]

intra_image, inter_image = resnet_model(image_view_a, device)
intra_image1, inter_image1 = resnet_model(image_view_b, device)
intra_cap, inter_cap = gpt_model(caption_a, device)
intra_cap1, inter_cap1 = gpt_model(caption_b, device)

In [None]:
# Smoke test: evaluate both loss objects on the outputs of the forward-pass cell.
intra_loss = intra_ContrastiveLoss(device, temperature=0.07)
newinter_loss = inter_ContrastiveLoss(margin=0.2, max_violation=True)

# Intra-modal loss between the two image outputs.
image_intra_value = intra_loss(intra_image, intra_image1, batch_size)
print(image_intra_value)

# The inter-modal loss returns two values; they are kept as `a` and `b` for later inspection.
a, b = newinter_loss(inter_image, inter_image1, inter_cap, inter_cap1)
print(a, b)

In [5]:
def train1(dataloader,data_type, image_model, text_model, optimizer_image, optimizer_text, intra_criterion,inter_criterion,device,
          scheduler_image=None, scheduler_text=None, trade_off_ii=1, trade_off_cc=1,trade_off_ic=1,trade_off_ci=1):
    loss_epoch = 0

    for idx, batch in tqdm(enumerate(dataloader)):
        image_model.train()
        text_model.train()
        intra_contrastive_loss=0
        batch_size = batch[0].shape[0]
        if data_type=='flickr_travel':
            image1, image2, caption1, caption2 = batch[0], batch[1], batch[3], batch[4]
        if data_type=='flickr30k':
            image1, image2, caption1, caption2 = batch[1], batch[2], batch[3], batch[4]
            
        intra_image,inter_image = image_model(image1, device)
        intra_image1,inter_image1 = image_model(image2, device)
        intra_contrastive_loss+=(trade_off_ii * intra_criterion(intra_image, intra_image1, batch_size))
        del intra_image , intra_image1
        intra_cap,inter_cap = text_model(caption1, device)
        intra_cap1,inter_cap1 = text_model(caption2, device)
        intra_contrastive_loss+=(trade_off_cc * intra_criterion(intra_cap, intra_cap1, batch_size))
        #intra_contrastive_loss = (trade_off_ii * intra_criterion(intra_image, intra_image1, batch_size) +
                            #trade_off_cc * intra_criterion(intra_cap, intra_cap1, batch_size))
            
            
        ci_loss, ic_loss=inter_criterion(inter_image,inter_image1,inter_cap,inter_cap1)
        del  inter_image,inter_image1,inter_cap,inter_cap1
        inter_contrastive_loss= trade_off_ci*ci_loss + trade_off_ic*ic_loss
        
        total_loss = intra_contrastive_loss + inter_contrastive_loss
        
        total_loss.backward()
        optimizer_image.step()
        optimizer_text.step()

        optimizer_image.zero_grad()
        optimizer_text.zero_grad()
        print(total_loss.item())
        loss_epoch += total_loss.item()
        #del batch, intra_image, inter_image, intra_image1, inter_image1, intra_cap
        #del inter_cap, intra_cap1, inter_cap1, intra_contrastive_loss, ci_loss, ic_loss, inter_contrastive_loss
    if scheduler_image:
        scheduler_image.step()
    if scheduler_text:
        scheduler_text.step()
    epoch_loss = loss_epoch / len(dataloader)
    return epoch_loss
def test1(dataloader, data_type, image_model, text_model,intra_criterion,inter_criterion, device, trade_off_ii=1,
         trade_off_cc=1,trade_off_ic=1,trade_off_ci=1):

    loss_epoch = 0

    with torch.no_grad():
        for idx, batch in enumerate(dataloader):
            image_model.eval()
            text_model.eval()
            batch_size = batch[0].shape[0]
            if data_type=='flickr_travel':
                image1, image2, caption1, caption2 = batch[0], batch[1], batch[3], batch[4]
            if data_type=='flickr30k':
                image1, image2, caption1, caption2 = batch[1], batch[2], batch[3], batch[4]
            

            intra_image,inter_image = image_model(image1, device)
            intra_image1,inter_image1 = image_model(image2, device)
            intra_cap,inter_cap = text_model(caption1, device)
            intra_cap1,inter_cap1 = text_model(caption2, device)

            intra_contrastive_loss = (trade_off_ii * intra_criterion(intra_image, intra_image1, batch_size) +
                                trade_off_cc * intra_criterion(intra_cap, intra_cap1, batch_size))


            ci_loss, ic_loss=inter_criterion(inter_image,inter_image1,inter_cap,inter_cap1)
            inter_contrastive_loss= trade_off_ci*ci_loss + trade_off_ic*ic_loss

            total_loss = intra_contrastive_loss + inter_contrastive_loss
            print(total_loss.item())
            loss_epoch += total_loss.item()

            #del batch, intra_image, inter_image, intra_image1, inter_image1, intra_cap
            #del inter_cap, intra_cap1, inter_cap1, intra_contrastive_loss, ci_loss, ic_loss, inter_contrastive_loss            
            torch.cuda.empty_cache()

    epoch_loss = loss_epoch / len(dataloader)
    return epoch_loss

In [6]:
# ---- Hyperparameters -------------------------------------------------------
projection_dim = 128            # passed as intra_projection_dim to both encoders
inter_projection_dim = 1024     # passed as inter_projection_dim to both encoders
encoder_last_layer = 2048       # NOTE(review): not referenced by any cell visible here
image_learning_rate = 0.001
text_learning_rate = 4e-5
momentum = 0.9
temperature = 0.07              # passed to the intra-modal loss
margin = 0.2                    # passed to the inter-modal loss
weight_decay = 0.0001
optimizer_type = 'sgd'
total_epochs = 100
trade_off_ii = 1
trade_off_cc = 1
trade_off_ic = 1e-4
trade_off_ci = 1e-4
image_layers = ['layer3', 'layer4']
text_layers = ['h.10', 'h.11']

# ---- Models ----------------------------------------------------------------
# The constructors read the constants above, so editing a hyperparameter
# actually changes the run instead of being shadowed by a repeated literal.
resnet_model = ResNetSimCLR(
    model='resnet50',
    intra_projection_dim=projection_dim,
    inter_projection_dim=inter_projection_dim,
    layers_to_train=image_layers,
    evaluate=False
).to(device)
print('resnet_trainable_Params', count_trainable_parameters(resnet_model))
get_gpu_memory()

gpt_model = OpenAI_SIMCLR(
    model='openai-gpt',
    intra_projection_dim=projection_dim,
    inter_projection_dim=inter_projection_dim,
    layers_to_train=text_layers,
    evaluate=False
).to(device)
print('gpt_trainable_Params', count_trainable_parameters(gpt_model))
get_gpu_memory()

# ---- Loss functions --------------------------------------------------------
intra_loss = intra_ContrastiveLoss(device, temperature=temperature)
newinter_loss = inter_ContrastiveLoss(margin=margin, max_violation=True)

# ---- Optimizers and schedulers ---------------------------------------------
# Optimizer_simclr exposes both an optimizer and a scheduler. The wrapper gets
# its own name so `optimizer_image` / `optimizer_text` only ever refer to the
# optimizer objects handed to the training loop.
image_optim_bundle = Optimizer_simclr(optimizer_name=optimizer_type,
                                      model_parameters=resnet_model.parameters(),
                                      lr=image_learning_rate,
                                      momentum=momentum,
                                      weight_decay=weight_decay)
scheduler_image = image_optim_bundle.scheduler
optimizer_image = image_optim_bundle.optimizer

text_optim_bundle = Optimizer_simclr(optimizer_name=optimizer_type,
                                     model_parameters=gpt_model.parameters(),
                                     lr=text_learning_rate,
                                     momentum=momentum,
                                     weight_decay=weight_decay)
scheduler_text = text_optim_bundle.scheduler
optimizer_text = text_optim_bundle.optimizer


resnet_trainable_Params 32816256
Free GPU memory: 16.64 GB


ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.


gpt_trainable_Params 16242816
Free GPU memory: 15.64 GB


In [8]:
# Epoch loop: one training pass, one validation pass, then a one-line summary.
# The epoch count comes from `total_epochs` in the hyperparameter cell rather
# than a repeated literal.
for epoch in range(total_epochs):
    start = time.time()
    train_loss = train1(dataloader=train_loader,
                        data_type='flickr30k',
                        image_model=resnet_model,
                        text_model=gpt_model,
                        optimizer_image=optimizer_image,
                        optimizer_text=optimizer_text,
                        intra_criterion=intra_loss,
                        inter_criterion=newinter_loss,
                        device=device,
                        scheduler_image=scheduler_image,
                        scheduler_text=scheduler_text,
                        trade_off_ii=trade_off_ii,
                        trade_off_cc=trade_off_cc,
                        trade_off_ic=trade_off_ic,
                        trade_off_ci=trade_off_ci)
    # Evaluate with the notebook-local test1, the counterpart of train1 defined
    # in the same cell, so both passes use the same loss code.
    # NOTE(review): the original called the imported `test`; switch back if that
    # was intentional.
    test_loss = test1(dataloader=val_loader,
                      data_type='flickr30k',
                      image_model=resnet_model,
                      text_model=gpt_model,
                      intra_criterion=intra_loss,
                      inter_criterion=newinter_loss,
                      device=device,
                      trade_off_ii=trade_off_ii,
                      trade_off_cc=trade_off_cc,
                      trade_off_ic=trade_off_ic,
                      trade_off_ci=trade_off_ci)
    end = time.time()
    print('trainloss', round(train_loss, 3), 'testloss', round(test_loss, 3), 'time', round(end - start, 1))

2it [00:02,  1.16it/s]

4.795647621154785
4.735579967498779


4it [00:02,  2.55it/s]

5.295603275299072
4.997143745422363


6it [00:03,  2.31it/s]

3.4373855590820312
3.651033639907837


8it [00:03,  3.49it/s]

3.968688488006592
4.212647438049316


10it [00:05,  2.34it/s]

4.303959369659424
3.8869311809539795


12it [00:05,  3.45it/s]

3.6686627864837646
3.570733070373535


14it [00:06,  2.49it/s]

4.222115993499756
3.3865818977355957


16it [00:06,  3.57it/s]

4.1088948249816895
4.392673492431641


17it [00:07,  2.05it/s]

3.485539436340332


19it [00:08,  2.87it/s]

3.8443102836608887
4.312129974365234


20it [00:08,  3.41it/s]

3.347177028656006


21it [00:09,  1.98it/s]

3.3131980895996094


23it [00:09,  2.93it/s]

3.164569616317749
3.8016433715820312


24it [00:09,  3.46it/s]

2.9321517944335938


26it [00:11,  2.44it/s]

4.187251091003418
4.655463218688965


28it [00:11,  3.50it/s]

3.4963090419769287
2.788294792175293


30it [00:12,  2.45it/s]

2.79693603515625
3.7958173751831055


32it [00:13,  3.53it/s]

3.5298662185668945
3.0587475299835205


34it [00:14,  2.40it/s]

3.96560001373291
3.2139806747436523


36it [00:14,  3.45it/s]

4.126687049865723
3.333146572113037


38it [00:15,  2.48it/s]

3.1181583404541016
2.703996181488037


40it [00:16,  3.60it/s]

2.9693078994750977
2.7560946941375732


42it [00:17,  2.45it/s]

3.033487558364868
2.582414388656616


44it [00:17,  3.53it/s]

2.0925192832946777
3.0265676975250244


46it [00:18,  2.44it/s]

2.28644061088562
2.4475502967834473


48it [00:19,  3.52it/s]

2.4545540809631348
2.5773417949676514


49it [00:20,  2.01it/s]

2.7749149799346924


51it [00:20,  2.88it/s]

2.1034891605377197
2.078591823577881


52it [00:20,  3.41it/s]

2.186657428741455


53it [00:21,  2.05it/s]

2.880791664123535


55it [00:22,  2.92it/s]

3.627610921859741
2.5199289321899414


56it [00:22,  3.45it/s]

2.0205342769622803


57it [00:23,  2.05it/s]

2.1251494884490967


59it [00:23,  2.94it/s]

3.3520138263702393
2.616237163543701


60it [00:23,  3.50it/s]

1.7887364625930786


61it [00:24,  2.04it/s]

2.1552324295043945


63it [00:25,  2.98it/s]

2.4986863136291504
2.879007577896118


64it [00:25,  3.51it/s]

2.290408134460449


65it [00:26,  2.04it/s]

1.8372350931167603


67it [00:26,  3.02it/s]

2.251201868057251
1.9373385906219482


68it [00:26,  3.57it/s]

3.0222885608673096


69it [00:27,  2.02it/s]

3.0628910064697266


71it [00:28,  2.90it/s]

1.7971960306167603
1.7242618799209595


72it [00:28,  3.44it/s]

2.5959534645080566


73it [00:29,  2.11it/s]

1.9714428186416626


75it [00:29,  2.90it/s]

2.0113942623138428
1.7254277467727661


76it [00:29,  3.43it/s]

1.8969669342041016


77it [00:30,  2.12it/s]

2.327840566635132


79it [00:31,  3.02it/s]

1.8093420267105103
2.3745248317718506


80it [00:31,  3.56it/s]

1.1169525384902954


81it [00:32,  2.10it/s]

2.163857936859131


83it [00:32,  2.92it/s]

2.161665916442871
1.7516289949417114


84it [00:32,  3.46it/s]

1.9151155948638916


85it [00:33,  2.14it/s]

2.240556240081787


87it [00:34,  2.89it/s]

1.3247919082641602
3.0099053382873535


88it [00:34,  3.44it/s]

1.8672453165054321


89it [00:35,  2.20it/s]

1.6379064321517944


91it [00:35,  2.84it/s]

1.9886784553527832
2.7670154571533203


92it [00:35,  3.39it/s]

2.0457763671875


93it [00:36,  2.19it/s]

2.4722859859466553


95it [00:37,  2.81it/s]

1.9390960931777954
2.8425962924957275


96it [00:37,  3.36it/s]

1.2108380794525146


97it [00:38,  2.26it/s]

1.8677574396133423


99it [00:38,  2.84it/s]

1.9104149341583252
2.009962558746338


100it [00:39,  3.39it/s]

1.917779803276062


101it [00:39,  2.31it/s]

1.9484374523162842


103it [00:40,  2.74it/s]

1.8980965614318848
2.5677499771118164


104it [00:40,  3.26it/s]

1.0007169246673584


105it [00:41,  2.37it/s]

2.092190980911255


107it [00:41,  2.90it/s]

1.5375926494598389
2.5697836875915527


108it [00:42,  3.41it/s]

1.7403897047042847


109it [00:42,  2.20it/s]

1.5686686038970947


111it [00:43,  2.87it/s]

1.510647177696228
2.4869322776794434


112it [00:43,  3.36it/s]

1.1314533948898315


113it [00:44,  2.22it/s]

1.8127169609069824


115it [00:44,  2.95it/s]

1.2385144233703613
1.3788193464279175


116it [00:45,  3.31it/s]

0.6318385601043701


117it [00:46,  2.15it/s]

2.862266778945923


119it [00:46,  2.95it/s]

0.8157999515533447
1.2920193672180176


120it [00:46,  3.32it/s]

1.9542388916015625


121it [00:47,  2.15it/s]

1.314186930656433


123it [00:47,  2.93it/s]

1.200817346572876
1.770222783088684


124it [00:48,  3.16it/s]

1.4227979183197021


125it [00:49,  2.19it/s]

1.750096082687378


127it [00:49,  2.92it/s]

1.5310114622116089
1.8866288661956787


128it [00:49,  3.30it/s]

3.100322723388672


129it [00:50,  2.15it/s]

1.6451048851013184


131it [00:51,  3.03it/s]

2.341808557510376
1.786875605583191


132it [00:51,  3.07it/s]

1.0694547891616821


133it [00:52,  2.14it/s]

1.4215999841690063


135it [00:52,  2.89it/s]

1.1737401485443115
0.897241473197937


136it [00:52,  3.04it/s]

1.6845240592956543


137it [00:53,  2.16it/s]

1.7669363021850586


139it [00:54,  2.96it/s]

1.1525158882141113
1.52326500415802


140it [00:54,  3.10it/s]

1.3483542203903198


141it [00:55,  2.14it/s]

1.560400366783142


143it [00:55,  2.92it/s]

1.779638409614563
1.9618968963623047


144it [00:55,  3.18it/s]

2.109067916870117


144it [00:56,  2.55it/s]


KeyboardInterrupt: 

In [None]:
# Shell escape: show the GPU status table again after the training cell.
!nvidia-smi

In [None]:
# Look at one position of each of the two values returned by the inter-modal
# loss in the smoke-test cell (`a` first, then `b`).
index = 7
for inter_term in (a, b):
    print(inter_term[index])

In [None]:
# Repeats the last lookup of the previous cell: entry `index` of `b`.
print(b[index])