In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import wandb
import os, sys
import glob
import numpy as np
import torch
import pandas as pd
import random
import torch.nn as nn
import pickle
import torch.nn.functional as F

In [3]:
from mstcn_model import *
from utility.adaptive_data_loader import Breakfast, collate_fn_override
from utils import calculate_mof, dotdict

In [4]:
os.environ["WANDB_API_KEY"] = "992b3b1371ba79f48484cfca522b3786d7fa52c2"
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdipika_singhania[0m (use `wandb login --relogin` to force relogin)


True

In [5]:
seed = 42

# Ensure deterministic behavior
def set_seed():
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed)
set_seed()

# Device configuration
os.environ['CUDA_VISIBLE_DEVICES']='1'
# os.environ['CUDA_LAUNCH_BLOCKING']='6'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [6]:
config = dotdict(
    epochs=500,
    num_class=48,
    batch_size=8,
    learning_rate=5e-4,
    weight_decay=0,
    dataset="Breakfast",
    architecture="unet-ensemble",
    features_file_name="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/features/",
    chunk_size=1,
    max_frames_per_video=1200,
    feature_size=2048,
    ground_truth_files_dir="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/groundTruth/",
    label_id_csv="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/mapping.csv",
    gamma=0.1,
    step_size=500,
    split=2,
#     output_dir="/mnt/data/ar-datasets/dipika/breakfast/ms_tcn/data/breakfast/results/unsuper-finetune-split2-0.05-data-llr/",
    output_dir="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast//results/mstcn-lenpsuedo-full-supervised-split2/",
    project_name="breakfast-split-2",
    train_split_file="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/splits/train.split{}.bundle",
    test_split_file="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/splits/test.split{}.bundle",
    all_files="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/splits/all_files.txt",
    cutoff=8,
    data_per = 0.2,
    budget=40,
    semi_supervised_split="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/semi_supervised/train.split{}_amt{}.bundle")

config.train_split_file = config.train_split_file.format(config.split)
config.semi_supervised_split = config.semi_supervised_split.format(config.split, config.data_per)
config.test_split_file = config.test_split_file.format(config.split)

if not os.path.exists(config.output_dir):
    os.mkdir(config.output_dir)

print(config)

{'epochs': 500, 'num_class': 48, 'batch_size': 8, 'learning_rate': 0.0005, 'weight_decay': 0, 'dataset': 'Breakfast', 'architecture': 'unet-ensemble', 'features_file_name': '/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/features/', 'chunk_size': 1, 'max_frames_per_video': 1200, 'feature_size': 2048, 'ground_truth_files_dir': '/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/groundTruth/', 'label_id_csv': '/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/mapping.csv', 'gamma': 0.1, 'step_size': 500, 'split': 2, 'output_dir': '/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast//results/mstcn-lenpsuedo-full-supervised-split2/', 'project_name': 'breakfast-split-2', 'train_split_file': '/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/splits/train.split2.bundle', 'test_split_file': '/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/splits/test.split2.bundle', 'all_files': '/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/splits/all_files.txt', 'cutoff': 8, 'data_per': 0.2, 'budget': 40, 'semi_supe

In [7]:
df=pd.read_csv(config.label_id_csv)
label_id_to_label_name = {}
label_name_to_label_id_dict = {}
for i, ele in df.iterrows():
    label_id_to_label_name[ele.label_id] = ele.label_name
    label_name_to_label_id_dict[ele.label_name] = ele.label_id

In [8]:
traindataset = Breakfast(config, fold='train', fold_file_name=config.train_split_file)
testdataset = Breakfast(config, fold='test', fold_file_name=config.test_split_file)

Number of videos logged in train fold is 1261
Number of videos not found in train fold is 0
Number of videos logged in test fold is 451
Number of videos not found in test fold is 0


In [9]:
def _init_fn(worker_id):
    np.random.seed(int(seed))
trainloader = torch.utils.data.DataLoader(dataset=traindataset,
                                          batch_size=config.batch_size, 
                                          shuffle=True,
                                          pin_memory=True, num_workers=4, 
                                          collate_fn=lambda x: collate_fn_override(x, config.max_frames_per_video),
                                          worker_init_fn=_init_fn)
testloader = torch.utils.data.DataLoader(dataset=testdataset,
                                          batch_size=config.batch_size, 
                                          shuffle=False,
                                          pin_memory=True, num_workers=4,
                                          collate_fn=lambda x: collate_fn_override(x, config.max_frames_per_video),
                                          worker_init_fn=_init_fn)

In [10]:
set_seed()
model = MultiStageModel(num_stages=4, num_layers=10, num_f_maps=64, dim=2048, num_classes=48).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

# Requires loaded_vidid_selected_frames, boundaries_dict
ce_criterion = nn.CrossEntropyLoss(ignore_index=-100)
mse_criterion = nn.MSELoss(reduction='none')

In [11]:
pseudo_labels_dir = "/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/length_segmentation_output/"
def get_single_random(output_p, video_ids):
    # Generate target for only timestamps. Do not generate pseudo labels at first 30 epochs.
    boundary_target_tensor = torch.ones((output_p.shape[0], output_p.shape[2]), dtype=torch.long, 
                                        device=output_p.device) * (-100)
    for iter_num, cur_vidid in enumerate(video_ids):
        pseudo_l = open(pseudo_labels_dir + cur_vidid + ".txt").read().split("\n")[0:-1]
        pseudo_l = [label_name_to_label_id_dict[ele] for ele in pseudo_l]
        abc = torch.tensor(pseudo_l).to(torch.long).to(boundary_target_tensor.device)
        frame_idx_tensor = torch.arange(0, len(pseudo_l), 1).to(device)
        boundary_target_tensor[iter_num, frame_idx_tensor] = abc

    return boundary_target_tensor

In [12]:
best_val_acc = 0
best_epoch = -1
for epoch in range(100):
    print("Starting Training")
    model.train()
    for i, item in enumerate(trainloader):
        item_0 = item[0].to(device)
        item_1 = item[1].to(device)
        item_2 = item[2].to(device)
        src_mask = torch.arange(item_2.shape[1], device=item_2.device)[None, :] < item_1[:, None]
        src_mask_mse = src_mask.unsqueeze(1).to(torch.float32).to(device)
        optimizer.zero_grad()
        
        middle_pred, predictions = model(item_0, src_mask_mse)
        psuedo_l = get_single_random(predictions[-1], item[4])
        loss = 0
        for p in predictions:
            loss += ce_criterion(p, psuedo_l)
            loss += 0.15 * torch.mean(torch.clamp(mse_criterion(F.log_softmax(p[:, :, 1:], dim=1), 
                                                                F.log_softmax(p.detach()[:, :, :-1], dim=1)), min=0,
                                        max=16) * src_mask_mse[:, :, 1:])
            

        loss.backward()
        optimizer.step()
        
        if i % 10 == 0:
            with torch.no_grad():
                pred = torch.argmax(predictions[-1], dim=1)
                correct = float(torch.sum((pred == item_2) * src_mask).item())
                total = float(torch.sum(src_mask).item())
                print(f"Training:: Epoch {epoch}, Iteration {i}, Current loss {loss.item()}" +
                      f" Accuracy {correct * 100.0 / total}")
    # Calculating Expectation Step
    model.eval()

    print("Calculating Validation Data Accuracy")
    correct = 0.0
    total = 0.0
    for i, item in enumerate(testloader):
        with torch.no_grad():
            item_0 = item[0].to(device)
            item_1 = item[1].to(device)
            item_2 = item[2].to(device)
            src_mask = torch.arange(item_2.shape[1], device=item_2.device)[None, :] < item_1[:, None]
            src_mask_mse = src_mask.unsqueeze(1).to(torch.float32).to(device)

            middle_pred, predictions = model(item_0, src_mask_mse)

            pred = torch.argmax(predictions[-1], dim=1)
            correct += float(torch.sum((pred == item_2) * src_mask).item())
            total += float(torch.sum(src_mask).item())
    val_acc = correct * 100.0 / total
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_epoch = epoch
        torch.save(model.state_dict(), config.output_dir + "ms-tcn-best-model.wt")
    torch.save(model.state_dict(), config.output_dir + "ms-tcn-last-model.wt")
    print(f"Validation:: Epoch {epoch}, Probability Accuracy {val_acc}")

Starting Training
Training:: Epoch 0, Iteration 0, Current loss 15.693472862243652 Accuracy 1.0729173785279846
Training:: Epoch 0, Iteration 10, Current loss 14.211379051208496 Accuracy 6.276704579647349
Training:: Epoch 0, Iteration 20, Current loss 13.300257682800293 Accuracy 9.14622178606477
Training:: Epoch 0, Iteration 30, Current loss 14.41281509399414 Accuracy 11.566045116703611
Training:: Epoch 0, Iteration 40, Current loss 12.287521362304688 Accuracy 19.669512256003195
Training:: Epoch 0, Iteration 50, Current loss 11.137622833251953 Accuracy 24.671816009144123
Training:: Epoch 0, Iteration 60, Current loss 11.745620727539062 Accuracy 19.477620270119434
Training:: Epoch 0, Iteration 70, Current loss 12.002619743347168 Accuracy 5.896333182435491
Training:: Epoch 0, Iteration 80, Current loss 12.47575569152832 Accuracy 11.876788672992921
Training:: Epoch 0, Iteration 90, Current loss 12.001688003540039 Accuracy 5.0972927241962775
Training:: Epoch 0, Iteration 100, Current loss 1

Training:: Epoch 5, Iteration 20, Current loss 5.061407089233398 Accuracy 53.11252055438102
Training:: Epoch 5, Iteration 30, Current loss 5.921935558319092 Accuracy 50.64615681357894
Training:: Epoch 5, Iteration 40, Current loss 5.548762321472168 Accuracy 52.94337507008036
Training:: Epoch 5, Iteration 50, Current loss 5.490309238433838 Accuracy 40.46831806601984
Training:: Epoch 5, Iteration 60, Current loss 5.163252830505371 Accuracy 43.79336931380108
Training:: Epoch 5, Iteration 70, Current loss 5.947635173797607 Accuracy 37.71370535462393
Training:: Epoch 5, Iteration 80, Current loss 5.593697547912598 Accuracy 49.99327414581652
Training:: Epoch 5, Iteration 90, Current loss 5.365443706512451 Accuracy 43.9320006296238
Training:: Epoch 5, Iteration 100, Current loss 5.97140645980835 Accuracy 53.53922452660054
Training:: Epoch 5, Iteration 110, Current loss 4.366034507751465 Accuracy 58.29713891761462
Training:: Epoch 5, Iteration 120, Current loss 5.584264755249023 Accuracy 39.82

Training:: Epoch 10, Iteration 50, Current loss 5.959776401519775 Accuracy 51.14178683224714
Training:: Epoch 10, Iteration 60, Current loss 4.2059807777404785 Accuracy 64.28117792559415
Training:: Epoch 10, Iteration 70, Current loss 4.595099449157715 Accuracy 53.57114228456914
Training:: Epoch 10, Iteration 80, Current loss 3.8198883533477783 Accuracy 66.82602921646746
Training:: Epoch 10, Iteration 90, Current loss 4.246981620788574 Accuracy 53.6848792884371
Training:: Epoch 10, Iteration 100, Current loss 5.0094685554504395 Accuracy 43.80563798219585
Training:: Epoch 10, Iteration 110, Current loss 4.2515106201171875 Accuracy 63.09654531632739
Training:: Epoch 10, Iteration 120, Current loss 4.159738063812256 Accuracy 67.70468230430484
Training:: Epoch 10, Iteration 130, Current loss 6.123836517333984 Accuracy 38.18364274962368
Training:: Epoch 10, Iteration 140, Current loss 4.195005416870117 Accuracy 66.78538162292477
Training:: Epoch 10, Iteration 150, Current loss 4.48370027542

Training:: Epoch 15, Iteration 70, Current loss 3.196843147277832 Accuracy 58.04139414031001
Training:: Epoch 15, Iteration 80, Current loss 3.788417339324951 Accuracy 58.47253775223841
Training:: Epoch 15, Iteration 90, Current loss 4.681201934814453 Accuracy 56.672386724597466
Training:: Epoch 15, Iteration 100, Current loss 3.9759609699249268 Accuracy 55.2106959532052
Training:: Epoch 15, Iteration 110, Current loss 6.86570930480957 Accuracy 50.14230534661517
Training:: Epoch 15, Iteration 120, Current loss 5.8196940422058105 Accuracy 44.182648401826484
Training:: Epoch 15, Iteration 130, Current loss 4.762230396270752 Accuracy 49.19566644780039
Training:: Epoch 15, Iteration 140, Current loss 5.319911003112793 Accuracy 51.261436096479066
Training:: Epoch 15, Iteration 150, Current loss 5.391932964324951 Accuracy 52.984078068823834
Calculating Validation Data Accuracy
Validation:: Epoch 15, Probability Accuracy 50.38675063181008
Starting Training
Training:: Epoch 16, Iteration 0, Cu

Training:: Epoch 20, Iteration 90, Current loss 3.1407456398010254 Accuracy 65.94149043983515
Training:: Epoch 20, Iteration 100, Current loss 3.091240882873535 Accuracy 70.54435662907379
Training:: Epoch 20, Iteration 110, Current loss 2.8897907733917236 Accuracy 59.20873966177015
Training:: Epoch 20, Iteration 120, Current loss 2.893315076828003 Accuracy 70.64614983373541
Training:: Epoch 20, Iteration 130, Current loss 3.020702600479126 Accuracy 66.06771915247195
Training:: Epoch 20, Iteration 140, Current loss 3.5447447299957275 Accuracy 66.86967530023608
Training:: Epoch 20, Iteration 150, Current loss 4.137700080871582 Accuracy 57.346005430898956
Calculating Validation Data Accuracy
Validation:: Epoch 20, Probability Accuracy 52.8105191200232
Starting Training
Training:: Epoch 21, Iteration 0, Current loss 3.4592385292053223 Accuracy 58.119758143163644
Training:: Epoch 21, Iteration 10, Current loss 5.798767566680908 Accuracy 42.748284214775936
Training:: Epoch 21, Iteration 20, 

Training:: Epoch 25, Iteration 110, Current loss 4.253875255584717 Accuracy 52.74508254996445
Training:: Epoch 25, Iteration 120, Current loss 3.464813470840454 Accuracy 55.069240877388054
Training:: Epoch 25, Iteration 130, Current loss 3.3248579502105713 Accuracy 60.970129526830554
Training:: Epoch 25, Iteration 140, Current loss 3.8698794841766357 Accuracy 58.38920880334464
Training:: Epoch 25, Iteration 150, Current loss 2.973176956176758 Accuracy 63.05624636224139
Calculating Validation Data Accuracy
Validation:: Epoch 25, Probability Accuracy 55.46132493681899
Starting Training
Training:: Epoch 26, Iteration 0, Current loss 3.374666213989258 Accuracy 58.43350285097858
Training:: Epoch 26, Iteration 10, Current loss 3.181309700012207 Accuracy 59.44004690706538
Training:: Epoch 26, Iteration 20, Current loss 3.2264745235443115 Accuracy 64.0451803002666
Training:: Epoch 26, Iteration 30, Current loss 3.1254286766052246 Accuracy 69.07715881567891
Training:: Epoch 26, Iteration 40, Cu

Training:: Epoch 30, Iteration 130, Current loss 2.6234376430511475 Accuracy 73.23383963263865
Training:: Epoch 30, Iteration 140, Current loss 2.2357916831970215 Accuracy 67.4499970843781
Training:: Epoch 30, Iteration 150, Current loss 2.2764673233032227 Accuracy 66.3452566096423
Calculating Validation Data Accuracy
Validation:: Epoch 30, Probability Accuracy 51.04559390147906
Starting Training
Training:: Epoch 31, Iteration 0, Current loss 2.4612395763397217 Accuracy 68.08035714285714
Training:: Epoch 31, Iteration 10, Current loss 2.2597451210021973 Accuracy 63.360297042832514
Training:: Epoch 31, Iteration 20, Current loss 1.8844285011291504 Accuracy 71.71239837398375
Training:: Epoch 31, Iteration 30, Current loss 1.3557124137878418 Accuracy 76.35499718128372
Training:: Epoch 31, Iteration 40, Current loss 2.649503469467163 Accuracy 65.29836000790358
Training:: Epoch 31, Iteration 50, Current loss 1.7850632667541504 Accuracy 63.77431254191818
Training:: Epoch 31, Iteration 60, Cu

Training:: Epoch 35, Iteration 150, Current loss 2.2645649909973145 Accuracy 71.05367793240556
Calculating Validation Data Accuracy
Validation:: Epoch 35, Probability Accuracy 51.85731449641629
Starting Training
Training:: Epoch 36, Iteration 0, Current loss 1.7844520807266235 Accuracy 65.67769130998703
Training:: Epoch 36, Iteration 10, Current loss 2.8463640213012695 Accuracy 54.83798379837984
Training:: Epoch 36, Iteration 20, Current loss 1.8433791399002075 Accuracy 70.42580438610533
Training:: Epoch 36, Iteration 30, Current loss 2.213423252105713 Accuracy 59.485116653258245
Training:: Epoch 36, Iteration 40, Current loss 2.001297950744629 Accuracy 57.43120848788062
Training:: Epoch 36, Iteration 50, Current loss 1.604368805885315 Accuracy 59.79952962257811
Training:: Epoch 36, Iteration 60, Current loss 1.9590870141983032 Accuracy 55.709010864162465
Training:: Epoch 36, Iteration 70, Current loss 1.3344309329986572 Accuracy 72.15124130050899
Training:: Epoch 36, Iteration 80, Cur

KeyboardInterrupt: 

In [16]:
config.output_dir + "ms-tcn-best-model.wt"

'/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast//results/mstcn-lenpsuedo-full-supervised-split1/ms-tcn-best-model.wt'

In [15]:
best_epoch

10

In [24]:
torch.save(model.state_dict(),
"/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast//results/em-maximize-mstcn-speed/final-em-maximized.wt")

In [34]:
model.load_state_dict(torch.load(f"/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast//results/em-maximize-mstcn-split3/ms-tcn-initial-15-epochs.wt"))

<All keys matched successfully>

In [25]:
print("Calculating Validation Data Accuracy")
correct = 0.0
total = 0.0
for i, item in enumerate(testloader):
    with torch.no_grad():
        item_0 = item[0].to(device)
        item_1 = item[1].to(device)
        item_2 = item[2].to(device)
        src_mask = torch.arange(item_2.shape[1], device=item_2.device)[None, :] < item_1[:, None]
        src_mask_mse = src_mask.unsqueeze(1).to(torch.float32).to(device)

        middle_pred, predictions = model(item_0, src_mask_mse)

        pred = torch.argmax(predictions[-1], dim=1)
        correct += float(torch.sum((pred == item_2) * src_mask).item())
        total += float(torch.sum(src_mask).item())

print(f"Validation:: Epoch {epoch}, Probability Accuracy {correct * 100.0 / total}")

Calculating Validation Data Accuracy
Validation:: Epoch 15, Probability Accuracy 57.770101729343025


In [9]:
import pickle
# pickle.dump(video_id_boundary_frames, open("dump_dir/video_id_boundary_frames_dict.pkl", "wb"))
# pickle.dump(loaded_vidid_selected_frames, open("dump_dir/loaded_vidid_selected_frames_dict.pkl", "wb"))
pickle.dump(boundary_dict, open("dump_dir/chunk_1_video_id_boundary_frames_dict.pkl", "wb"))