In [1]:
import argparse
import torch
import torchvision
from torchvision import transforms
from PIL import Image
import requests
import time
import io
from io import BytesIO
import matplotlib.pyplot as plt
%matplotlib inline
import torchvision.models as models
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import os
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
import random
from tqdm import tqdm
import json
from torch.optim.lr_scheduler import CosineAnnealingLR
import threading
import torchvision.models as models
import torch.nn as nn
from pytorch_pretrained_bert import OpenAIGPTTokenizer, OpenAIGPTModel, OpenAIGPTLMHeadModel
from nltk.corpus import wordnet
from caption_transforms import SimCLRData_Caption_Transform
from image_transforms import SimCLRData_image_Transform
from dataset import FlickrDataset
from models import ResNetSimCLR,OpenAI_SIMCLR
from utils import get_gpu_stats,layerwise_trainable_parameters,count_trainable_parameters
from metrics import ContrastiveLoss
from metrics import LARS,Optimizer_simclr
from logger import Logger
from train_fns import train, test

In [18]:
def get_args() -> argparse.Namespace:
    """
    Parse command line arguments.

    Returns:
        argparse.Namespace: the parsed command line arguments
    """
    parser = argparse.ArgumentParser(description='SimCLR with Image and Caption Data')
    parser.add_argument('--seed',
                        type=int,
                        default=53,
                        help='random_seed')
    parser.add_argument('--modality_type',
                        type=str,
                        default='image_caption',
                        help='type of modality (image_caption or image)')
    parser.add_argument('--resnet_model',
                        type=str,
                        default='resnet50',
                        help='type of ResNet model to use (resnet18, resnet34, resnet50, resnet101, resnet152)')
    parser.add_argument('--gpt_model',
                        type=str,
                        default='openai-gpt',
                        help='type of GPT model to use (gpt, gpt2, gpt2-medium, gpt2-large)')
    parser.add_argument('--image_projection_dim',
                        type=int,
                        default=128,
                        help='dimension of the projected image embedding')
    parser.add_argument('--text_projection_dim',
                        type=int,
                        default=128,
                        help='dimension of the projected text embedding')
    parser.add_argument('--resnet_layer',
                        type=str,
                        default='layer4',
                        help='which ResNet layer to use as the encoder')
    parser.add_argument('--gpt_layer',
                        type=str,
                        default='h.11',
                        help='which GPT layer to use as the encoder')
    parser.add_argument('--temperature',
                        type=float,
                        default=0.07,
                        help='temperature parameter for contrastive loss')
    parser.add_argument('--total_epochs',
                        type=int,
                        default=100,
                        help='number of total epochs to train the model')
    parser.add_argument('--batch_size',
                        type=int,
                        default=64,
                        help='batch size for training')
    parser.add_argument('--optimizer_type',
                        type=str,
                        default='sgd',
                        help='type of optimizer to use (adam, sgd, lars)')
    parser.add_argument('--learning_rate',
                        type=float,
                        default=0.03,
                        help='learning rate for optimizer')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.09,
                        help='momentum for optimizer')
    parser.add_argument('--weight_decay',
                        type=float,
                        default=1e-4,
                        help='weight decay for optimizer')
    parser.add_argument('--num_workers',
                        type=int,
                        default=4,
                        help='number of workers for dataloader')
    parser.add_argument('--trade_off_ii',
                        type=float,
                        default=1,
                        help='trade-off weight for image-image similarity loss')
    parser.add_argument('--trade_off_cc',
                        type=float,
                        default=1,
                        help='trade-off weight for caption-caption similarity loss')
    parser.add_argument('--graph_save_dir',
                        type=str,
                        default='/home1/08629/pradhakr/cv_project/graphs/image_caption',
                        help='directory to save the loss graphs')
    parser.add_argument('--trial_number',
                        type=int,
                        help='trial number')

    args = parser.parse_args()

    return args



def main(args):

    random.seed(args.seed)
    torch.manual_seed(args.seed)


    # Create train and test datasets with image and caption transforms
    train_dataset = FlickrDataset('data/', "data/train", 'train',
                                  image_transform=SimCLRData_image_Transform(),
                                  caption_transform=SimCLRData_Caption_Transform())

    test_dataset = FlickrDataset('data/', "data/test", 'test',
                                 image_transform=SimCLRData_image_Transform(),
                                 caption_transform=SimCLRData_Caption_Transform())

    # Set the batch size and create train and test data loaders
    batch_size = args.batch_size
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=args.num_workers,
                                  pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=args.num_workers,
                                 pin_memory=True)

    # Set device to CUDA if available, otherwise to CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize ResNetSimCLR model
    model_resnet = ResNetSimCLR(
            model=args.resnet_model,
            projection_dim=args.image_projection_dim,
            layers_to_train=[args.resnet_layer]
        ).to(device)

        # Initialize OpenAI_SIMCLR model
    gpt_model = OpenAI_SIMCLR(
            model=args.gpt_model,
            projection_dim=args.text_projection_dim,
            layers_to_train=[args.gpt_layer]
        ).to(device)

        # Define loss function
    NXTENT_loss = ContrastiveLoss(device, temperature=args.temperature)

        # Define optimizers and schedulers
    optimizer_image = Optimizer_simclr(optimizer_name=args.optimizer_type,
                                           model_parameters=model_resnet.parameters(),
                                           lr=args.learning_rate,
                                           momentum=args.momentum,
                                           weight_decay=args.weight_decay)

    scheduler_image = optimizer_image.scheduler
    optimizer_image = optimizer_image.optimizer

    optimizer_text = Optimizer_simclr(optimizer_name=args.optimizer_type,
                                          model_parameters=gpt_model.parameters(),
                                          lr=args.learning_rate,
                                          momentum=args.momentum,
                                          weight_decay=args.weight_decay)

    scheduler_text = optimizer_text.scheduler
    optimizer_text = optimizer_text.optimizer

    # Initialize trial number

    # Define paths for logs and files
    log_dir = os.path.join(os.getenv('WORK'), 'cv_project')
    image_caption_filename = os.path.join(log_dir, args.modality_type)
    train_log = os.path.join(image_caption_filename, f'train_{args.trial_number}.log')
    image_model_log = os.path.join(image_caption_filename, f'image_model_{args.trial_number}.pth')
    text_model_log = os.path.join(image_caption_filename, f'text_model_{args.trial_number}.pth')
    graph_save_dir = os.path.join(args.graph_save_dir, args.modality_type)

    # Create a logger object and start training
    logger_save = Logger(train_log, image_model_log, text_model_log, args.optimizer_type,
                         args.learning_rate, args.weight_decay, batch_size, args.momentum, args.temperature,
                         args.total_epochs)
    logger_save.start_training()
    # Loop through epochs and train the models
    for epoch in tqdm(range(args.total_epochs)):

        start = time.time()

        # Train the models and get the loss
        train_loss = train(dataloader=train_loader, image_model=model_resnet, text_model=gpt_model,
                           optimizer_image=optimizer_image, optimizer_text=optimizer_text, criterion=NXTENT_loss,
                           scheduler_image=scheduler_image, scheduler_text=scheduler_text, device=device,
                           trade_off_ii=args.trade_off_ii, trade_off_cc=args.trade_off_cc)

        # Test the models and get the loss
        test_loss = test(dataloader=test_loader, image_model=model_resnet, text_model=gpt_model, criterion=NXTENT_loss, device=device
                         ,trade_off_ii=args.trade_off_ii, trade_off_cc=args.trade_off_cc)

        end = time.time()

        # Log the results of the epoch
        logger_save.log(epoch + 1, model_resnet, gpt_model, train_loss, test_loss, end - start)

    # End training and plot the losses
    logger_save.end_training()
    logger_save.plot_losses(args.trial_number, graph_save_dir)
    
    
if __name__ == '__main__':
    # Parse command-line arguments
    args = get_args()

    # Call the main function with the parsed arguments
    main(args)

usage: ipykernel_launcher.py [-h] [--seed SEED]
                             [--modality_type MODALITY_TYPE]
                             [--resnet_model RESNET_MODEL]
                             [--gpt_model GPT_MODEL]
                             [--image_projection_dim IMAGE_PROJECTION_DIM]
                             [--text_projection_dim TEXT_PROJECTION_DIM]
                             [--resnet_layer RESNET_LAYER]
                             [--gpt_layer GPT_LAYER]
                             [--temperature TEMPERATURE]
                             [--total_epochs TOTAL_EPOCHS]
                             [--batch_size BATCH_SIZE]
                             [--optimizer_type OPTIMIZER_TYPE]
                             [--learning_rate LEARNING_RATE]
                             [--momentum MOMENTUM]
                             [--weight_decay WEIGHT_DECAY]
                             [--num_workers NUM_WORKERS]
                             [--trade_off_ii TRADE_OFF_II

SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
