In [None]:
import torch
from torch.utils.data import DataLoader, ConcatDataset
import torch.nn as nn
import os
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from torchinfo import summary
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
import matplotlib.pyplot as plt 
from utils.timeseriesdataset import TimeSeriesDataset
from utils.pad_batch import pad_batch, LABEL_PADDING_VALUE
from models.ClassificationModel import ClassificationModel, NUM_CLASSES
import pickle 
from pathlib import Path
from utils.load_data import load_data
from sklearn.metrics import confusion_matrix
import numpy as np

# Use the first CUDA device when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Training hyperparameters.
BATCH_SIZE, EPOCHS = 32, 35
LEARNING_RATE, WEIGHT_DECAY = 1e-3, 2e-6

# Hand cached-but-unused GPU memory back to the driver before starting.
torch.cuda.empty_cache()
print('The model is running on:', DEVICE)

# Create DataLoaders

In [None]:
DATA_ROOT = Path("../data/simulated_tracks")


def find_split_files(split, data_root=DATA_ROOT):
    """Return the ``<split>_instances.pkl`` files found one level below ``data_root``.

    The result is sorted so the order in which instances are concatenated does
    not depend on the filesystem's directory listing order.
    """
    return sorted(Path(data_root).glob(f"*/{split}_instances.pkl"))


def load_instances(files):
    """Unpickle every file in ``files`` and concatenate their contents into one list."""
    instances = []
    for file in files:
        # NOTE: pickle can execute arbitrary code while loading; only read trusted files.
        with open(file, "rb") as f:
            instances.extend(pickle.load(f))
    return instances


train_files = find_split_files("train")
val_files = find_split_files("val")
test_files = find_split_files("test")

train_instances = load_instances(train_files)
val_instances = load_instances(val_files)
test_instances = load_instances(test_files)

print("Train data: ", len(train_instances), "Test data: ", len(test_instances), "Val data: ", len(val_instances))

In [None]:
# Merge the per-folder datasets of each split into a single dataset.
conc_train = ConcatDataset(train_instances)
conc_val = ConcatDataset(val_instances)
conc_test = ConcatDataset(test_instances)

# Only the training split is shuffled. Validation and test batches keep a fixed
# order so batch-averaged metrics and per-sample indices are the same on every run.
train_loader = DataLoader(conc_train, batch_size=BATCH_SIZE, shuffle=True, collate_fn=pad_batch)
test_loader = DataLoader(conc_test, batch_size=BATCH_SIZE, shuffle=False, collate_fn=pad_batch)
val_loader = DataLoader(conc_val, batch_size=BATCH_SIZE, shuffle=False, collate_fn=pad_batch)

# The training loop reads `training_loader`; expose this loader under that name too
# so training also works when the data comes from the pickled instances.
training_loader = train_loader

print("DataLoader Sizes:", len(train_loader), len(test_loader), len(val_loader))

# DATA_PATH = "../data/simulated_tracks"
# filepaths = list(Path(DATA_PATH).rglob('*.parquet'))
# random.shuffle(filepaths)

# print("Number of files found:", len(filepaths))

# train_data = []
# test_data = []
# val_data = []

# train_data = [TimeSeriesDataset(filepath, augment=True) for filepath in filepaths[:int(len(filepaths)*0.7)]]
# test_data = [TimeSeriesDataset(filepath, augment=False) for filepath in filepaths[int(len(filepaths)*0.7):int(len(filepaths)*0.85)]]
# val_data = [TimeSeriesDataset(filepath, augment=False) for filepath in filepaths[int(len(filepaths)*0.85):]]

In [None]:
# load_data() is a project helper; it is unpacked here as (train, val, test) collections of datasets.
train_data, val_data, test_data = load_data()

print("Train data: ", len(train_data))
print("Val data: ", len(val_data))
print("Test data: ", len(test_data))

# Merge the individual datasets of each split into a single dataset.
conc_train = ConcatDataset(train_data)
conc_val = ConcatDataset(val_data)
conc_test = ConcatDataset(test_data)

# Only the training split is shuffled. Validation and test batches keep a fixed
# order so batch-averaged metrics and per-sample indices are the same on every run.
training_loader = DataLoader(conc_train, batch_size=BATCH_SIZE, shuffle=True, collate_fn=pad_batch)
val_loader = DataLoader(conc_val, batch_size=BATCH_SIZE, shuffle=False, collate_fn=pad_batch)
test_loader = DataLoader(conc_test, batch_size=BATCH_SIZE, shuffle=False, collate_fn=pad_batch)

print("Data", len(training_loader), len(val_loader), len(test_loader))

# Class Weights 
State labels are imbalanced, so the training loss is weighted per class (the validation/test loss is left unweighted).

In [4]:
# class_counts = torch.zeros(NUM_CLASSES, dtype=torch.long, device=DEVICE)
# progress_bar = tqdm(total=len(training_loader), desc='Weights Finder', position=0)

# for _, _, _, state_labels in training_loader:
    
#     states = state_labels.to(DEVICE).flatten()
#     states = states[states != LABEL_PADDING_VALUE]
#     class_counts += torch.bincount(states.long(), minlength=class_counts.numel())
#     progress_bar.update()

# total_samples = class_counts.sum().item()

# weights = class_counts.float().reciprocal() * total_samples
# normalized_weights = torch.tensor(weights / weights.sum(), device=DEVICE)

# progress_bar.close()
# print("Class Weights (0,1,2,3):",normalized_weights)

# Per-class loss weights for states 0-3, copied from an earlier run of the
# commented-out computation above (that run printed them on device 'cuda:1').
class_weight_values = [0.1835, 0.1536, 0.0173, 0.6456]
normalized_weights = torch.tensor(class_weight_values, device=DEVICE)

# Model
Set up the model, optimizer, learning-rate scheduler, and loss functions. <br>
Focal Loss is commented out but can be used for training instead.

In [7]:
# # If using Focal Loss remove the log_softmax from the model as cross_entropy already applies it
# class FocalLoss(nn.Module):
#     def __init__(self, gamma=3):
#         super(FocalLoss, self).__init__()
#         self.gamma = gamma

#     def forward(self, output, targets):
        
#         ce_loss = F.cross_entropy(output, targets, reduction='none', ignore_index=LABEL_PADDING_VALUE)
#         mask = (targets != LABEL_PADDING_VALUE).float()
#         pt = torch.exp(-ce_loss)
#         focal_weight = (1 - pt) ** self.gamma
#         loss = focal_weight * ce_loss * mask

#         return loss.sum() / mask.sum()

In [None]:
# Network under training, placed on the selected device.
model = ClassificationModel().to(DEVICE)

# One timestamp names both the TensorBoard run and the checkpoint folder of this session.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('models/checkpoints/state_runs/runs_{}'.format(timestamp))
model_directory = os.path.join('models/checkpoints/state_model', 'model_{}'.format(timestamp))

# Layer/parameter overview for a dummy input of shape (BATCH_SIZE, 200, 10)
# (presumably (batch, sequence length, features) -- TODO confirm against the model).
print(summary(model, input_size=(BATCH_SIZE, 200, 10)))

# Optimisation: Adam, with the learning rate cut by 10x once the monitored
# value has stopped improving for more than 3 scheduler steps.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)
scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.1, patience=3)
best_val_loss = float("inf")

# Losses: class-weighted NLL for training, unweighted NLL for evaluation.
# Padded label positions are ignored by both.
# state_loss_fn = FocalLoss()
state_loss_fn = nn.NLLLoss(
    weight=normalized_weights.to(DEVICE),
    ignore_index=LABEL_PADDING_VALUE,
    reduction='mean',
)
state_loss_fn_inference = nn.NLLLoss(
    ignore_index=LABEL_PADDING_VALUE,
    reduction='mean',
)

# Training Functions

In [10]:
def train_one_epoch(model, optimizer, dataloader, loss_fn=None, progress=None, device=None):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Args:
        model: network called as ``model(inputs)``; its 3-D output has the last
            two axes swapped before being passed to the loss.
        optimizer: optimizer stepped once per batch.
        dataloader: iterable yielding 4-tuples ``(inputs, _, _, state_labels)``;
            the two middle elements are ignored.
        loss_fn: loss called as ``loss_fn(log_probs, labels)``. Defaults to the
            notebook-level ``state_loss_fn``.
        progress: object with an ``update()`` method (e.g. a tqdm bar). Defaults
            to the notebook-level ``progress_bar`` if one exists, otherwise no
            progress is reported.
        device: device the batch tensors are moved to. Defaults to the
            notebook-level ``DEVICE``.

    Returns:
        The average of the per-batch loss values as a float.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    # Fall back to the notebook-level objects so existing three-argument calls keep working.
    if loss_fn is None:
        loss_fn = state_loss_fn
    if device is None:
        device = DEVICE
    if progress is None:
        progress = globals().get("progress_bar")

    model.train()
    running_loss = 0.0
    runs = 0

    for inputs, _, _, state_labels in dataloader:

        inputs, state_labels = inputs.to(device), state_labels.to(device)
        outputs = model(inputs)

        # The loss expects the class axis second: (batch, classes, sequence).
        state_log_probs = outputs.permute(0, 2, 1)
        loss_state = loss_fn(state_log_probs, state_labels.long())

        optimizer.zero_grad()
        loss_state.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
        optimizer.step()
        running_loss += loss_state.item()
        runs += 1

        if progress is not None:
            progress.update()

    if runs == 0:
        raise ValueError("train_one_epoch received an empty dataloader")

    return running_loss / runs

def evaluate_model(model, dataloader, loss_fn=None, progress=None, device=None):
    """Evaluate ``model`` on ``dataloader`` without updating it.

    Args:
        model: network called as ``model(inputs)``; its 3-D output has the last
            two axes swapped before being passed to the loss, and the argmax
            over its last axis is taken as the predicted class.
        dataloader: iterable yielding 4-tuples ``(inputs, _, _, state_labels)``;
            the two middle elements are ignored.
        loss_fn: loss called as ``loss_fn(log_probs, labels)``. Defaults to the
            notebook-level ``state_loss_fn_inference``.
        progress: object with an ``update()`` method (e.g. a tqdm bar). Defaults
            to the notebook-level ``progress_bar`` if one exists, otherwise no
            progress is reported.
        device: device the batch tensors are moved to. Defaults to the
            notebook-level ``DEVICE``.

    Returns:
        ``(mean_loss, conf_matrix)``: the average of the per-batch losses, and a
        4x4 confusion matrix over labels 0-3, normalised per true class and
        expressed in percent. Positions whose label equals
        ``LABEL_PADDING_VALUE`` are excluded from the matrix.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    # Fall back to the notebook-level objects so existing two-argument calls keep working.
    if loss_fn is None:
        loss_fn = state_loss_fn_inference
    if device is None:
        device = DEVICE
    if progress is None:
        progress = globals().get("progress_bar")

    model.eval()

    running_val_total = 0.0
    val_runs = 0

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, _, _, state_labels in dataloader:

            inputs, state_labels = inputs.to(device), state_labels.to(device)
            outputs = model(inputs)
            # The loss expects the class axis second: (batch, classes, sequence).
            state_log_probs = outputs.permute(0, 2, 1)

            loss_state = loss_fn(state_log_probs, state_labels.long())
            predictions = torch.argmax(outputs, dim=-1)

            # Keep only real (non-padded) positions, and park them on the CPU so
            # accumulated results do not pile up in device memory.
            mask = (state_labels != LABEL_PADDING_VALUE)
            all_preds.append(torch.masked_select(predictions, mask).cpu())
            all_labels.append(torch.masked_select(state_labels, mask).cpu())

            running_val_total += loss_state.item()
            val_runs += 1
            if progress is not None:
                progress.update()

    if val_runs == 0:
        raise ValueError("evaluate_model received an empty dataloader")

    all_preds = torch.cat(all_preds).numpy()
    all_labels = torch.cat(all_labels).numpy()
    # normalize='true' divides each row by its true-class count, so the diagonal is per-class recall.
    conf_matrix = 100*np.round(confusion_matrix(all_labels, all_preds, labels=[0, 1, 2, 3], normalize='true'),2)

    return running_val_total / val_runs, conf_matrix

# Train

In [None]:
os.makedirs(model_directory, exist_ok=True)

# Path of the checkpoint with the lowest validation loss seen so far.
best_model_path = None

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch + 1))

    # One bar per epoch. It is advanced by both the training and the validation
    # pass, so its total covers both loaders. It is always closed, even on error.
    progress_bar = tqdm(total=len(training_loader) + len(val_loader), desc='Training', position=0)
    try:
        avg_training_loss = train_one_epoch(model, optimizer, training_loader)
        val_total_loss, matrix = evaluate_model(model, val_loader)
    finally:
        progress_bar.close()

    print(f'Training LOSS: State {avg_training_loss}\n'
          f'Validation LOSS: State {val_total_loss} \nState Conf \n{matrix} \n')

    writer.add_scalars('Losses', {
        'Training Total': avg_training_loss,
        'Validation Total': val_total_loss,
        }, epoch + 1)

    writer.flush()

    # A checkpoint is written every epoch; the suffix is the validation loss in
    # hundredths. Scale before rounding: int(round(x, 2) * 100) truncates values
    # such as 0.29 to 28 because 0.29 * 100 == 28.999999999999996.
    model_path = os.path.join(model_directory, f'model_{epoch + 1}_{int(round(val_total_loss * 100))}')
    torch.save(model.state_dict(), model_path)

    if val_total_loss < best_val_loss:
        best_val_loss = val_total_loss
        best_model_path = model_path

    # Let the scheduler lower the learning rate when the validation loss plateaus.
    scheduler.step(val_total_loss)

writer.close()
print('Best checkpoint:', best_model_path)

In [None]:
print("Best Validation Loss:", best_val_loss)
# `model_path` is reassigned on every epoch of the training loop, so it names the
# most recently saved checkpoint, which is not necessarily the best one.
print("Last Model Path", model_path)

# model 19 is best with the weights in the validation loss

# without the weights in the validation loss model 15:
# Training: 100%|██████████| 105000/105000 [52:46<00:00, 33.16it/s]
# Training: 100%|██████████| 105000/105000 [40:21<00:00, 41.51it/s]
# Training LOSS: State 0.23815074014839316
# Validation LOSS: State 0.3648213621613623 
# State Conf 
# [[99.  0.  0.  0.]
#  [ 1. 92.  5.  1.]
#  [ 1.  8. 82.  9.]
#  [ 0.  1.  6. 93.]] 

# Testing

In [None]:
# NOTE(review): absolute, machine-specific path; point it at the checkpoint to evaluate.
model_path = "/home/haidiri/Desktop/AnDiChallenge2024/models/checkpoints/state_model/model_20241001_115850_weighed_on_val/model_25_23"
# map_location lets a checkpoint saved on one device be loaded on whichever device this session uses.
model.load_state_dict(torch.load(model_path, map_location=DEVICE))
model.eval()

running_test_total = 0.0
test_runs = 0

# Per-sample model outputs and per-sample true labels (padding included), kept for plotting.
predictions_list = []
ground_truth_list = []

# Flat predicted / true labels of all non-padded positions, for the confusion matrix.
all_preds = []
all_labels = []

progress_bar = tqdm(total=len(test_loader), desc='Testing', position=0)

with torch.no_grad():
    for inputs, _, _, state_labels in test_loader:

        inputs, state_labels = inputs.to(DEVICE), state_labels.to(DEVICE)

        mask = (state_labels != LABEL_PADDING_VALUE)
        outputs = model(inputs).squeeze(-1)

        # The loss expects the class axis second: (batch, classes, sequence).
        state_log_probs = outputs.permute(0, 2, 1)

        loss_state = state_loss_fn_inference(state_log_probs, state_labels.long())
        predictions = torch.argmax(outputs, dim=-1)

        all_preds.append(torch.masked_select(predictions, mask))
        all_labels.append(torch.masked_select(state_labels, mask))
        predictions_list.extend(outputs.cpu().numpy())
        # Store the true labels here (previously the predictions were stored by mistake).
        ground_truth_list.extend(state_labels.cpu().numpy())

        running_test_total += loss_state.item()
        test_runs += 1

        progress_bar.update()

progress_bar.close()

all_preds = torch.cat(all_preds).cpu().numpy()
all_labels = torch.cat(all_labels).cpu().numpy()
# normalize='true' divides each row by its true-class count, so the diagonal is per-class recall (in percent).
conf_matrix = 100*np.round(confusion_matrix(all_labels, all_preds, labels=[0, 1, 2, 3], normalize='true'),2)
# Calculate average losses
avg_test_loss = running_test_total / test_runs
print(f'Average test loss: {avg_test_loss}')
print(conf_matrix)

In [None]:
# Mean of the confusion-matrix diagonal (per-class recall in percent), computed
# from the matrix of the testing cell instead of numbers copied by hand.
print(round(float(np.mean(np.diag(conf_matrix))), 2))

# Plot Predictions

In [None]:
INDEX = 53  # position of the test sample to plot

# Per-sample arrays collected by the testing cell.
true_states = np.asarray(ground_truth_list[INDEX])
pred_states = np.asarray(predictions_list[INDEX])
if pred_states.ndim == 2:
    # Per-class scores were stored: reduce them to the most likely class per step.
    pred_states = pred_states.argmax(axis=-1)

# Cut both series at the first padded label (assumes padding sits at the end of
# the sequence -- TODO confirm against pad_batch); keep everything if none is padded.
is_padding = true_states == LABEL_PADDING_VALUE
valid_length = int(is_padding.argmax()) if is_padding.any() else len(true_states)

pred_states = pred_states[:valid_length]
true_states = true_states[:valid_length]

fig, ax = plt.subplots()
steps = np.arange(valid_length)
ax.scatter(steps, pred_states, color="red", label="predicted")
ax.scatter(steps, true_states, color="blue", label="ground truth")
ax.set(title=f"State (test sample {INDEX})", xlabel="time step", ylabel="state")
ax.legend()
plt.show()