In [1]:
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple
import os
import flwr as fl
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10
from torchsummary import summary
from torchvision.io import read_image
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
from enum import Enum
from tqdm import tqdm
import gc
from numba import cuda
import networkx as nx
import random
import time
import enum
import threading
from matplotlib import pyplot as plt
from zod import ZodFrames
from zod import ZodSequences
import zod.constants as constants
from zod.constants import Camera, Lidar, Anonymization, AnnotationProject
import json
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import requests
import cv2
from flask import Flask, request, jsonify
import multiprocessing
from zod.visualization.oxts_on_image import visualize_oxts_on_image
from zod.constants import Camera
from zod.data_classes.calibration import Calibration
from zod.data_classes.oxts import EgoMotion
from zod.utils.polygon_transformations import polygons_to_binary_mask
from zod.utils.geometry import (
    get_points_in_camera_fov,
    project_3d_to_2d_kannala,
    transform_points,
)
import pytorch_lightning as pl
import segmentation_models_pytorch as smp
from pprint import pprint
from torch.utils.data import DataLoader
from PIL import Image
from statistics import mean
from ema_pytorch import EMA
import sys

  from zod.data_classes.oxts import EgoMotion


In [4]:
# read unused frames
with open("frames_with_less_than_165m_hp.json") as f: UNUSED_FRAMES = set(json.load(f))

# read the config
with open("../config.json") as f: configs = json.load(f)
print(sys.argv)
if(len(sys.argv) <= 1 or len(sys.argv[1]) > 2): 
    config = configs[-1]
else:
    config = [c for c in configs if c['exp_id'] == int(sys.argv[1])][0]
print(config)

# helper function to read from config
c = lambda a : config[a]

# specify the device
DEVICE = torch.device("cuda" if c('use_gpu') else "cpu")

print(f"PyTorch={torch.__version__}. Pytorch vision={torchvision.__version__}. Flower={fl.__version__}")
print(f"Training will run on: {DEVICE}s")

# path to tensor board persistent folders
DISC = f"exp-{c('exp_id')}_{c('type')}_agent-{c('agent_id')}_{c('model')}_{c('dataset_division')}_{c('loss')}_{c('target_distances')[-1]}m_imgnet_normalized_{c('num_local_epochs')}epochs_lr{c('learning_rate')}_{c('subset_factor')*34000}trainImages_bs{c('batch_size')}_imgSize{c('image_size')}_unfreezed_ema-{c('use_ema')}"
now = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
TB_PATH = f"TensorBoard/{DISC}_{now}"
TB_CENTRALIZED_SUB_PATH = "TensorBoard_Centralized/loss/"
TB_FEDERATED_SUB_PATH = "TensorBoard_Federated/loss/"
TB_SWARM_SUB_PATH = "TensorBoard_Swarm/loss/"
print(DISC)

# global tensorboard writer
writer = SummaryWriter(TB_PATH)

['/home/yasan/miniconda3/envs/zen/lib/python3.9/site-packages/ipykernel_launcher.py', '--ip=127.0.0.1', '--stdin=9016', '--control=9014', '--hb=9013', '--Session.signature_scheme="hmac-sha256"', '--Session.key=b"d9bef1ea-7034-4455-b820-1b63e9089596"', '--shell=9015', '--transport="tcp"', '--iopub=9017', '--f=/home/yasan/.local/share/jupyter/runtime/kernel-v2-4147035uyByFrpxEMe7.json']
{'exp_id': 3, 'type': 'federated', 'agent_id': 0, 'output_size': 51, 'image_size': 256, 'run_pretrained': True, 'batch_size': 8, 'val_factor': 0.1, 'subset_factor': 0.002, 'use_gpu': True, 'model': 'resnet18', 'num_clients': 2, 'num_global_rounds': 2, 'num_local_epochs': 1, 'print_debug_data': True, 'num_workers': 0, 'prefetch_factor': None, 'frames_image_mean': [0.337, 0.345, 0.367], 'frames_image_std': [0.16, 0.18, 0.214], 'dataset_root': '/mnt/ZOD', 'zenseact_dataset_root': '/staging/dataset_donation/round_2', 'checkpoint_path': None, 'start_from_checkpoint': False, 'use_ema': True, 'dataset_division':

In [5]:

def get_ground_truth(zod_frames, frame_id):
    # get frame
    zod_frame = zod_frames[frame_id]

    # extract oxts
    oxts = zod_frame.oxts

    # get timestamp
    key_timestamp = zod_frame.info.keyframe_time.timestamp()

    try:
        # get posses associated with frame timestamp
        current_pose = oxts.get_poses(key_timestamp)

        # transform poses
        all_poses = oxts.poses[oxts.timestamps>=key_timestamp]
        transformed_poses = np.linalg.pinv(current_pose) @ all_poses

        # get translations
        translations = transformed_poses[:, :3, 3]

        # calculate acc diff distance
        distances = np.linalg.norm(np.diff(translations, axis=0), axis=1)
        accumulated_distances = np.cumsum(distances).astype(int).tolist()

        # get the poses that each have a point having a distance from TARGET_DISTANCES
        pose_idx = [accumulated_distances.index(i) for i in c('target_distances')]
        used_poses = transformed_poses[pose_idx]

    except:
        print("detected invalid frame: ", frame_id)
        return np.array([])

    #print(used_poses.shape)
    points = used_poses[:, :3, -1]
    return points.flatten()
   

def save_dataset_tb_plot(tb_path, sample_distribution, subtitle, seed):
    plt.bar(list(range(1, len(sample_distribution) + 1)), sample_distribution)
    plt.xlabel("Partitions")
    plt.ylabel("Samples")
    plt.suptitle("Distribution of samples")
    plt.title("%s, seed: %s" % (subtitle, seed)),

    """report to tensor board"""
    writer.add_figure("sample_distribution/%s" % (subtitle), plt.gcf(), global_step=0)


def reshape_ground_truth(label, output_size=c('output_size')):
    return label.reshape(((c('output_size') // 3), 3))

def visualize_HP_on_image(zod_frames, frame_id, preds=None, showImg=True):
    """Visualize oxts track on image plane."""
    camera=Camera.FRONT
    zod_frame = zod_frames[frame_id]
    image = zod_frame.get_image(Anonymization.DNAT)
    calibs = zod_frame.calibration
    points_gt = get_ground_truth(zod_frames, frame_id)
    preds_row = preds.copy()
    points = reshape_ground_truth(points_gt)
    
    circle_size = 15
    
    # transform point to camera coordinate system
    T_inv = np.linalg.pinv(calibs.get_extrinsics(camera).transform)
    camerapoints = transform_points(points[:, :3], T_inv)

    # filter points that are not in the camera field of view
    points_in_fov = get_points_in_camera_fov(calibs.cameras[camera].field_of_view, camerapoints)
    points_in_fov = points_in_fov[0]

    # project points to image plane
    xy_array = project_3d_to_2d_kannala(
        points_in_fov,
        calibs.cameras[camera].intrinsics[..., :3],
        calibs.cameras[camera].distortion,
    )
    
    ground_truth_color = (19, 80, 41)
    preds_color = (161, 65, 137)
    
    points = []
    for i in range(xy_array.shape[0]):
        x, y = int(xy_array[i, 0]), int(xy_array[i, 1])
        cv2.circle(image, (x,y), circle_size, ground_truth_color, -1)
        points.append([x,y])
    
    """Draw a line in image."""
    def draw_line(image, line, color):
        return cv2.polylines(image.copy(), [np.round(line).astype(np.int32)], isClosed=False, color=color, thickness=20)
    
    image = draw_line(image, points, ground_truth_color)
    
    # transform and draw predictions 
    if(preds is not None):
        preds = reshape_ground_truth(preds)
        predpoints = transform_points(preds[:, :3], T_inv)
        predpoints_in_fov = get_points_in_camera_fov(calibs.cameras[camera].field_of_view, predpoints)
        predpoints_in_fov = predpoints_in_fov[0]
        
        xy_array_preds = project_3d_to_2d_kannala(
            predpoints_in_fov,
            calibs.cameras[camera].intrinsics[..., :3],
            calibs.cameras[camera].distortion,
        )
        preds = []
        for i in range(xy_array_preds.shape[0]):
            x, y = int(xy_array_preds[i, 0]), int(xy_array_preds[i, 1])
            cv2.circle(image, (x,y), circle_size, preds_color, -1)
            preds.append([x,y])
        
        #preds = preds[:(len(preds)//2)]
        image = draw_line(image, preds, preds_color)
        
    #plt.imsave(f'inference_{frame_id}.png', image)
    if(showImg):
        plt.clf()
        plt.axis("off")
        plt.imshow(image)
    return image, points_gt, preds_row
    
def visualize_multiple(zod_frames, frame_ids, model=None):
    if(model):
        images = [visualize_HP_on_image(zod_frames, frame_id, predict(model.to(DEVICE), zod_frames, frame_id), showImg=False) for frame_id in frame_ids]
    else:
        images = [visualize_HP_on_image(zod_frames, frame_id, None, showImg=False) for frame_id in frame_ids]
        
    plt.figure(figsize=(60,60))
    columns = 4
    plt.subplots_adjust(wspace=0, hspace=0)
    for i, image in enumerate(images):
        plt.subplot(len(images) // columns + 1, columns, i + 1)
        plt.gca().set_title(frame_ids[i], fontsize=20)
        plt.imshow(image[0])  
    
def get_transformed_image(zod_frames, frame_id):
    frame = zod_frames[frame_id]
    image_path = frame.info.get_key_camera_frame(Anonymization.DNAT).filepath
    image = np.array(Image.open(image_path).convert("RGB"))
    image = np.array(Image.fromarray(image).resize((c('image_size'), c('image_size')), Image.BILINEAR))
    transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(size=(c('image_size'), c('image_size')), antialias=True)
        ])
    image = transform(image).unsqueeze(0).to(DEVICE)
    return image

def predict(model, zod_frames, frame_id):
    image = get_transformed_image(zod_frames, frame_id).to(DEVICE)
    outputs = model(image).squeeze(0)
    label = torch.from_numpy(get_ground_truth(zod_frames, frame_id)).to(DEVICE)
    
    loss = torch.nn.L1Loss(label, outputs)

    print(f'frame: {frame_id}, loss: {loss.item()}')
    preds = outputs.cpu().detach().numpy()
    return preds

def load_HP(dataset_root):
    zod_frames = ZodFrames(dataset_root=dataset_root, version='full')
    training_frames_all = zod_frames.get_split(constants.TRAIN)
    validation_frames_all = zod_frames.get_split(constants.VAL)

    return zod_frames, training_frames_all, validation_frames_all

def is_valid(frame_id):
    return frame_id not in UNUSED_FRAMES

def save_to_json(path, data):
    with open(path, 'w') as f:
        f.write(json.dumps(data))
        
def load_from_json(json_path):
    with open(json_path, 'r') as f:
        return json.load(f)

In [10]:
class ZODImporter:
    def __init__(
        self,
        root=c('dataset_root'),
        subset_factor=c('subset_factor'),
        img_size=c('image_size'),
        batch_size=c('batch_size'),
        tb_path=TB_PATH,
        zod_frames=None,
        training_frames=None, 
        validation_frames=None
    ):
        if(zod_frames == None):
            self.zod_frames = ZodFrames(dataset_root=root, version='full')

            self.training_frames_all = self.zod_frames.get_split(constants.TRAIN)
            self.validation_frames_all = self.zod_frames.get_split(constants.VAL)
            
            self.training_frames, self.validation_frames = self.get_train_val_ids(
                self.training_frames_all, 
                self.validation_frames_all, 
                subset_factor)
        else:
            self.zod_frames = zod_frames
            self.training_frames = training_frames
            self.validation_frames = validation_frames
            
        print("length of training_frames subset:", len(self.training_frames))
        print("length of validation_frames subset:", len(self.validation_frames))

        self.img_size = img_size
        self.batch_size = batch_size
        self.tb_path = tb_path
    
    def get_train_val_ids(self, training_frames_all, validation_frames_all, subset_factor):
        if(c('dataset_division') == 'balanced'):
            with open("../balanced_train_ids.txt") as f:
                training_frames_all = json.load(f)
                print(f'balanced sample: {training_frames_all[:5]}')

        training_frames = list(training_frames_all)[: int(len(training_frames_all) * subset_factor)]
        validation_frames = list(validation_frames_all)[: int(len(validation_frames_all) * subset_factor)]

        training_frames = [x for x in tqdm(training_frames) if is_valid(x)]
        validation_frames = [x for x in tqdm(validation_frames) if is_valid(x)]

        return training_frames, validation_frames
        
    def is_valid(self, frame_id):
        return frame_id not in UNUSED_FRAMES
        
    def load_datasets(self, num_clients=c('num_clients')):
        seed = 42
        imagenet_mean=[0.485, 0.456, 0.406]
        imagenet_std=[0.229, 0.224, 0.225]

        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(imagenet_mean, imagenet_std),
            transforms.Resize(size=(self.img_size, self.img_size), antialias=True)
        ])

        trainset = ZodDataset(zod_frames=self.zod_frames, frames_id_set=self.training_frames, transform=transform)
        testset = ZodDataset(zod_frames=self.zod_frames, frames_id_set=self.validation_frames, transform=transform)

        # Split training set into `num_clients` partitions to simulate different local datasets
        partition_size = len(trainset) // num_clients

        lengths = [partition_size]
        if num_clients > 1:
            lengths = [partition_size] * (num_clients - 1)
            lengths.append(len(trainset) - sum(lengths))

        datasets = random_split(trainset, lengths, torch.Generator().manual_seed(seed))

        # Split each partition into train/val and create DataLoader
        trainloaders, valloaders = [], []
        lengths_train, lengths_val = [], []
        for ds in datasets:
            len_val = int(len(ds) * c('val_factor'))
            len_train = int(len(ds) - len_val)
            lengths_train.append(len_train)
            lengths_val.append(len_val)
            lengths = [len_train, len_val]
            ds_train, ds_val = random_split(ds, lengths, torch.Generator().manual_seed(seed))
            trainloaders.append(DataLoader(ds_train,batch_size=self.batch_size, shuffle=True, num_workers=c('num_workers')))
            valloaders.append(DataLoader(ds_val, batch_size=self.batch_size, num_workers=c('num_workers')))

        len_complete_val = int(len(trainset) * c('val_factor'))
        len_complete_train = int(len(trainset) - len_complete_val)
        train_split, val_split = random_split(
            trainset,
            [len_complete_train, len_complete_val],
            torch.Generator().manual_seed(seed),
        )

        completeTrainloader = DataLoader(
            train_split, batch_size=self.batch_size, num_workers=c('num_workers'), shuffle=True, 
            prefetch_factor=c('prefetch_factor'),
            pin_memory= True)
        
        completeValloader = DataLoader(
            val_split, batch_size=self.batch_size, num_workers=c('num_workers'), shuffle=True,
            prefetch_factor=c('prefetch_factor'),
            pin_memory= True)

        testloader = DataLoader(testset, batch_size=self.batch_size, num_workers=c('num_workers'))

        """report to tensor board"""
        save_dataset_tb_plot(self.tb_path, lengths_train, "training", seed)
        save_dataset_tb_plot(self.tb_path, lengths_val, "validation", seed)
        save_dataset_tb_plot(self.tb_path, [len(testset)], "testing", seed)

        return (
            trainloaders,
            valloaders,
            testloader,
            completeTrainloader,
            completeValloader,
        )

class ZodDataset(Dataset):
    def __init__(
        self,
        zod_frames,
        frames_id_set,
        transform=None,
        target_transform=None,
    ):
        self.zod_frames = zod_frames
        self.frames_id_set = frames_id_set
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.frames_id_set)
    
    def __getitem__(self, idx):
        
        # load frame
        frame_idx = self.frames_id_set[idx]
        frame = self.zod_frames[frame_idx]
        
        # get image
        image_path = frame.info.get_key_camera_frame(Anonymization.DNAT).filepath
        image = np.array(Image.open(image_path).convert("RGB"))
        
        # extract ground truth
        label = get_ground_truth(self.zod_frames, frame_idx)
        
        # create sample
        sample = dict(image=image, label=label)
        
        # resize images
        image = np.array(Image.fromarray(sample["image"]).resize((c('image_size'), c('image_size')), Image.BILINEAR))
        
        # convert to other format HWC -> CHW
        #sample["image"] = np.moveaxis(image, -1, 0)
        sample["label"] = np.expand_dims(label, 0).astype(np.float32)
        
        if(self.transform):
            sample["image"] = self.transform(sample["image"])
        
        return sample

In [7]:
class PT_Model(pl.LightningModule):
    def __init__(self, cid=0) -> None:
        super(PT_Model, self).__init__()
        
        self.model = None
        if(c('model') == 'resnet18'):
            self.model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        elif(c('model') == 'mobile_net'):
            self.model = models.mobilenet_v3_large(weights=models.MobileNet_V3_Large_Weights.IMAGENET1K_V1)

        self.change_head_net()
        self.useEma = c('use_ema')
        self.ema = None

        if(self.useEma):
            self.ema = EMA(
                self.model,
                beta=0.995,
                update_after_step=100,
                power=3/4,
                inv_gamma=1.0
            )

        self.is_pretrained = True
        self.loss_fn = torch.nn.L1Loss()
        self.cid = cid
        
        # pytorch imagenet calculated mean/std
        self.mean=[0.485, 0.456, 0.406]
        self.std=[0.229, 0.224, 0.225]
        
        self.epoch_counter = 1

        self.inter_train_outputs = []
        self.inter_train_ema_outputs = []

        self.inter_val_outputs = []
        self.inter_val_ema_outputs = []

        self.inter_test_outputs = []
        self.inter_test_ema_outputs = []

    def forward(self, image):
        label = self.model(image)

        if(self.useEma):
            ema_label = self.ema(image)
            return label, ema_label

        return label, None

    def model_parameters(self):
        return self.model.fc.parameters()

    def change_head_net(self):
        num_ftrs = 0

        if(c('model') == 'resnet18'):
            num_ftrs = self.model.fc.in_features
        elif(c('model') == 'mobile_net'):
            num_ftrs = self.model.classifier[-1].in_features

        head_net = nn.Sequential(
            nn.Linear(num_ftrs, 1024, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 512, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(512, c('output_size'), bias=True),
        )

        if(c('model') == 'resnet18'):
            self.model.fc = head_net
        elif(c('model') == 'mobile_net'):
            self.model.classifier[-1] = head_net
            
    def shared_step(self, batch, batch_idx, stage, inter_outputs, ema_inter_outputs):
        image = batch["image"]
        label = batch["label"]

        logits_label, ema_logits_label = self.forward(image)
        logits_label = logits_label.unsqueeze(dim=1)
        loss = self.loss_fn(logits_label, label)

        if(self.useEma):
            ema_logits_label = ema_logits_label.unsqueeze(dim=1)
            ema_loss = self.loss_fn(ema_logits_label, label)

        ema_loss = ema_loss if(self.useEma) else None
        
        inter_outputs.append(loss.item())
        if(self.useEma):
            ema_inter_outputs.append(ema_loss.item())
            if(stage == 'train'):
                self.ema.update()

        if(batch_idx == 1 and stage != 'test'):
            self.updateTB(inter_outputs, stage)
            
            if(self.useEma):
                self.updateTB(ema_inter_outputs, f'{stage}_ema')

            if(stage == 'valid'):
                self.epoch_counter +=1

            print('DEBUG self.inter_train_outputs', self.inter_train_outputs)
            print('DEBUG Epoch loss:', loss.item())
            print('DEBUG Ema epoch loss:', ema_loss.item())
        
        print('DEBUG batch loss:', loss.item())
        print('DEBUG Ema batch loss:', ema_loss.item())

        return loss

    def shared_epoch_end(self, inter_outputs, ema_inter_outputs, stage):
        metrics = {
            f"{stage}_loss": inter_outputs[-1],
            f"{stage}_ema_loss": ema_inter_outputs[-1] if(self.useEma) else None,
        }
        
        self.log_dict(metrics, prog_bar=True)

    def training_step(self, batch, batch_idx):
        loss = self.shared_step(batch, batch_idx, 'train', self.inter_train_outputs, self.inter_train_ema_outputs)
        return loss

    def on_train_epoch_end(self):
        return self.shared_epoch_end(self.inter_train_outputs, self.inter_train_ema_outputs, 'train')

    def validation_step(self, batch, batch_idx):
        loss = self.shared_step(batch, batch_idx, 'valid', self.inter_val_outputs, self.inter_val_ema_outputs)
        return loss

    def on_validation_epoch_end(self):
        return self.shared_epoch_end(self.inter_val_outputs, self.inter_val_ema_outputs, 'valid')

    def test_step(self, batch, batch_idx):
        loss = self.shared_step(batch, batch_idx, 'test', self.inter_test_outputs, self.inter_test_ema_outputs)
        return loss

    def on_test_epoch_end(self):
        return self.shared_epoch_end(self.inter_test_outputs, self.inter_test_ema_outputs, 'test')

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=c('learning_rate'))
    
    def compute_metrics(pred_trajectory, target_trajectory):
        # L1 and L2 distance: matrix of size BSx40x3
        L1_loss = torch.abs(pred_trajectory - target_trajectory)
        L2_loss = torch.pow(pred_trajectory - target_trajectory, 2)

        # BSx40x3 -> BSx3 average over the predicted points
        L1_loss = L1_loss.mean(axis=1)
        L2_loss = L2_loss.mean(axis=1)

        # split into losses for each axis and an avg loss across 3 axes
        # All returned tensors have shape (BS)
        return {
                'L1_loss':   L1_loss.mean(axis=1),
                'L1_loss_x': L1_loss[:, 0],
                'L1_loss_y': L1_loss[:, 1],
                'L1_loss_z': L1_loss[:, 2],
                'L2_loss':   L2_loss.mean(axis=1),
                'L2_loss_x': L2_loss[:, 0],
                'L2_loss_y': L2_loss[:, 1],
                'L2_loss_z': L2_loss[:, 2]}
    
    def get_TB_path(self):
        if(c('type')=='centralized'):
            return TB_CENTRALIZED_SUB_PATH
        if(c('type')=='federated'):
            return f"{TB_FEDERATED_SUB_PATH}{self.cid}/",

    def updateTB(self, inter_outputs, stage):
        epoch_loss = np.mean(inter_outputs) 
        writer.add_scalars(self.get_TB_path(), {stage: epoch_loss},self.epoch_counter)
        inter_outputs = [epoch_loss]

In [8]:
def train(model, train_dataloader, valid_dataloader, nr_epochs=c('num_local_epochs')):
    trainer = get_trainer()

    trainer.fit(
        model, 
        train_dataloaders=train_dataloader, 
        val_dataloaders=valid_dataloader,
    )

    return trainer

def validate(model, valid_dataloader):
    trainer = get_trainer()
    valid_metrics = trainer.validate(model, dataloaders=valid_dataloader, verbose=False)
    pprint(valid_metrics)

def test(model, test_dataloader):
    trainer = get_trainer()
    test_metrics = trainer.test(model, dataloaders=test_dataloader, verbose=False)
    pprint(test_metrics)
    return test_metrics

def get_trainer():
    return pl.Trainer(
        accelerator= 'gpu',
        max_epochs=c('num_local_epochs'),
        devices=[c('gpu_id')],
    )

def net_instance(name):
    print(f"🌻 Created new model - {name} 🌻")
    return PT_Model()

def get_parameters(net, cid):
    print(f"⤺ Get model parameters of client {cid}")
    return [val.cpu().numpy() for _, val in net.model.state_dict().items()]

def set_parameters(net, parameters: List[np.ndarray], cid):
    print(f"⤻ Set model parameters of client {cid}")
    params_dict = zip(net.model.state_dict().keys(), parameters)
    state_dict = OrderedDict(
        {
            k: torch.Tensor(v) if v.shape != torch.Size([]) else torch.Tensor([0])
            for k, v in params_dict
        }
    )
    net.model.load_state_dict(state_dict, strict=True)

def save_model(net, name):
    print(f"🔒 Saved the model of client {name} to the disk. 🔒")
    torch.save(net.model.state_dict(), f"{name}.pth")

def load_model(name):
    print(f"🛅 Loaded the model of client {name} from the disk. 🛅")
    net = net_instance(f"{name}")
    net.model.load_state_dict(torch.load(f"{name}.pth"))
    return net

In [11]:
# get loaders
trainloaders, valloaders, testloader, completeTrainloader, completeValloader = ZODImporter().load_datasets(num_clients=1)

Loading infos: 0it [00:00, ?it/s]

balanced sample: ['003298', '074927', '038839', '021686', '088515']


100%|██████████| 70/70 [00:00<00:00, 199457.39it/s]
100%|██████████| 20/20 [00:00<00:00, 7936.99it/s]

length of training_frames subset: 70
length of validation_frames subset: 14





In [12]:
# create model
model = PT_Model.load_from_checkpoint(c('checkpoint_path')) if(c('start_from_checkpoint')) else PT_Model()

# train supervised
train(model, completeTrainloader, completeValloader, nr_epochs=c('num_local_epochs'))

# validate 
validate(model, completeValloader)

# test
test(model, testloader)

writer.close()

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name    | Type   | Params
-----------------------------------
0 | model   | ResNet | 12.3 M
1 | ema     | EMA    | 24.5 M
2 | loss_fn | L1Loss | 0     
-----------------------------------
12.3 M    Trainable params
12.3 M    Non-trainable params
24.5 M    Total params
98.022    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


DEBUG batch loss: 16.392702102661133
DEBUG Ema batch loss: 16.392702102661133


  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

DEBUG batch loss: 16.660438537597656
DEBUG Ema batch loss: 16.660438537597656


AttributeError: 'tuple' object has no attribute 'replace'