In [9]:
def qsgd(x, d):
    """quantize the tensor x in d level on the absolute value coef wise"""
    norm = torch.norm(x, 'fro')
    level_float = d * torch.abs(x) / norm
    previous_level = torch.floor(level_float)
    is_next_level = torch.rand(x.shape) < (level_float - previous_level)
    new_level = previous_level + is_next_level
    return torch.sign(x) * norm * new_level / d

In [10]:
from collections import OrderedDict, namedtuple
from itertools import product
import os
from tqdm import tqdm, trange
from IPython.display import clear_output
import time
import torch.nn as nn
import pandas as pd
import torch
import matplotlib.pyplot as plt
from utils.run_manager import RunBuilder
from models import g_step, DeepGCCA
from synth_data import create_synthData

class RunBuilder():
    @staticmethod
    def get_runs(params):

        Run = namedtuple('Run', params.keys())

        runs = []
        for v in product(*params.values()):
            runs.append(Run(*v))

        return runs

if torch.cuda.is_available():
    devices = ['cuda']
else:
    devices = ['cpu']
print('starting')


params = OrderedDict(
    lr = [0.001],
    batch_size = [1000],
    device = devices,
    shuffle = [True],
    num_workers = [5],
    manual_seed = [1265],
    loss_func = [nn.MSELoss],
    inner_epochs = [50],
    quant = [True], 
    random_compress = [True],
    n_bits = [2]
)

layer_sizes_list = 3*[[128, 64, 2]]
input_size_list = 3*[2]


run_count = 0
models = []


run_data = []

data_load_time = 0
forward_time = 0


for run in RunBuilder.get_runs(params):
#     torch.cuda.set_device(run.device)
    Nlevels = 2**(run.n_bits-1)-1
    run_count += 1
    device = torch.device(run.device)
    
    dgcca = DeepGCCA(layer_sizes_list, input_size_list)
    dgcca = dgcca.to(device)
    
    train_views, classes = create_synthData(N=10000)
    val_views, classes = create_synthData(N=200)
    suffler = torch.randperm(10000)
    
    train_views = [view[suffler].to(device) for view in train_views]
    val_views = [view.to(device) for view in val_views]
    
    optimizer = torch.optim.Adam(dgcca.parameters(), lr=run.lr)
    num_batches = len(train_views[0])//run.batch_size
    
    criterion = run.loss_func()
    num_val_batches = len(val_views[0])//run.batch_size
    
    # init G
    dgcca.eval()
    out = dgcca(train_views)
    out = torch.stack(out)  
    G = g_step(out.clone().detach())  
    M_serv = out.clone()
    M_diff = out.clone()
    dgcca.train()
    
    for epoch in range(50):
        total_recons_loss = 0
        total_val_loss = 0
        batch_count = 0
        
        results = OrderedDict()
        results['run_count'] = run_count
        results['epoch'] = epoch
        results['data_fidelity'] = total_recons_loss/num_batches
        results['val_fidelity'] = total_recons_loss/num_batches
        
        results['batch_size'] = run.batch_size
        results['lr'] = run.lr
        results['device'] = run.device

        df2 = pd.DataFrame.from_dict(run_data, orient='columns')
        
        for j in range(run.inner_epochs):
            for i in range(num_batches):
                optimizer.zero_grad()
                batch = []
                
                # SGD
                batch = [view[(i*run.batch_size):((i+1)*run.batch_size), :] for view in train_views]            
                target = G[(i*run.batch_size):((i+1)*run.batch_size), :]

                # full gradient
#                 batch = train_views
#                 target = G

                out = dgcca(batch)
                out = torch.stack(out)  
                
                loss = 1/2*torch.norm(out-target)/target.shape[0]
                
#                 print(loss.item())
                
                loss.backward()
                optimizer.step()
                
                total_recons_loss += loss.item()
                
        ## Update G
        dgcca.eval()
        out = dgcca(train_views)
        out = torch.stack(out)
        
        if run.quant:
            for i in range(len(train_views)):
                M_diff[i] = out[i] - M_serv[i]
                max_val = M_diff[i].abs().max()
                
                if run.random_compress:
                    M_quant = qsgd(M_diff[i], Nlevels)
                else:
                    M_quant = ((Nlevels/max_val)*M_diff[i]).round()*(max_val/Nlevels)
                    
                M_serv[i] += M_quant
            G = g_step(M_serv.clone().detach())          
        else:
            G = g_step(out.clone().detach())  
            
        
        # validation loss
        out_val = dgcca(val_views)
        out_val = torch.stack(out_val)
        G_val = g_step(out_val.clone().detach())
        loss_val = 1/2*torch.norm(out_val - G_val)/G_val.shape[0]
        total_val_loss = loss_val.item()

        dgcca.train()
        
        results = OrderedDict()
        results['run_count'] = run_count
        results['epoch'] = epoch
        results['data_fidelity'] = total_recons_loss/(num_batches*run.inner_epochs)
        results['val_fidelity'] = total_val_loss
        results['batch_size'] = run.batch_size
        results['lr'] = run.lr
        results['device'] = run.device
        
        run_data.append(results)
        df2 = pd.DataFrame.from_dict(run_data, orient='columns')
        try:
            clear_output(wait=True)
            display(df2)
        except:
            pass
        torch.save(dgcca, 'trained_models/dgcca_federated_sgd_random_compressor.model')

starting


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

In [8]:
M_diff[0].shape

torch.Size([10000, 2])

In [None]:
2**(2-1)-1