In [1]:
import os
import h5py
import torch
import pickle
import numpy as np
import pandas as pd

from src.strand.functions import *

In [2]:
def ckpt_to_csv_save(save_dir, ckpt, rank=None):
    """Export the parameters stored in a checkpoint to a directory of CSV files.

    Files written into ``save_dir`` (created if missing):

    * ``cl.csv``, ``cg.csv``, ``tl.csv``, ``tg.csv`` -- the four components of
      ``_T0`` after ``logit_to_distribution``; rows ``tri_1..tri_96``.
    * ``t.csv``, ``r.csv`` -- rows labelled ``plus`` / ``minus``.
    * ``e.csv``, ``n.csv``, ``c.csv`` -- default integer row index.
    * ``theta.csv`` -- ``lamb`` after ``logit_to_distribution``; one column per sample.
    * ``Sigma_matrix.csv``, ``Gamma.csv`` (from ``Xi``), ``sigma.csv``.
    * ``zeta_rank_<k>.csv`` for ``k = 1..rank-1``.
    * ``Delta/Delta_<i>.csv`` -- one file per slice along the first axis of ``Delta``.

    Parameters
    ----------
    save_dir : str
        Output directory.
    ckpt : dict
        Loaded checkpoint; only ``ckpt['state_dict']`` is read.
    rank : int, optional
        Number of ``rank_*`` columns. When omitted it is taken from the number
        of columns of the ``_t`` table, so the function no longer depends on a
        notebook-global ``rank`` variable.

    Raises
    ------
    ValueError
        Raised by pandas if an array's shape does not match the labels
        derived from ``rank``.
    """
    os.makedirs(save_dir, exist_ok=True)
    state = ckpt['state_dict']

    t = logit_to_distribution(state['_t']).numpy()
    r = logit_to_distribution(state['_r']).numpy()
    e = logit_to_distribution(state['_e']).numpy()
    n = logit_to_distribution(state['_n']).numpy()
    c = logit_to_distribution(state['_c']).numpy()

    if rank is None:
        # Every table below uses one column per rank, so the column count of
        # `t` is the only value that can satisfy the label lengths.
        rank = t.shape[1]

    rank_labels = [f'rank_{i}' for i in range(1, rank + 1)]
    # Sigma_mat, Xi, sigma and Delta are labelled with rank - 1 entries.
    reduced_labels = [f'rank_{i}' for i in range(1, rank)]
    tri_labels = [f'tri_{i}' for i in range(1, 97)]
    strand_labels = ['plus', 'minus']

    [[_cl, _cg], [_tl, _tg]] = state['_T0']

    # NOTE: the previous version called torch.stack() on the NumPy arrays here,
    # which raises TypeError; the stacked value was never used, so it is gone.
    for name, logits in (('cl', _cl), ('cg', _cg), ('tl', _tl), ('tg', _tg)):
        pd.DataFrame(
            logit_to_distribution(logits).numpy(),
            columns=rank_labels,
            index=tri_labels,
        ).to_csv(os.path.join(save_dir, f"{name}.csv"))

    for name, values in (('t', t), ('r', r)):
        pd.DataFrame(
            values, columns=rank_labels, index=strand_labels
        ).to_csv(os.path.join(save_dir, f"{name}.csv"))

    for name, values in (('e', e), ('n', n), ('c', c)):
        pd.DataFrame(
            values, columns=rank_labels
        ).to_csv(os.path.join(save_dir, f"{name}.csv"))

    lamb = state['lamb']
    pd.DataFrame(
        logit_to_distribution(lamb).numpy(),
        columns=[f'sample_{i}' for i in range(1, lamb.shape[1] + 1)],
        index=rank_labels,
    ).to_csv(os.path.join(save_dir, "theta.csv"))

    pd.DataFrame(
        state['Sigma_mat'].numpy(),
        columns=reduced_labels,
        index=reduced_labels,
    ).to_csv(os.path.join(save_dir, "Sigma_matrix.csv"))

    xi = state['Xi']
    pd.DataFrame(
        xi.numpy(),
        columns=[f'feature_{i}' for i in range(1, xi.shape[1] + 1)],
        index=reduced_labels,
    ).to_csv(os.path.join(save_dir, "Gamma.csv"))

    # Distinct loop names: the old loops reused `r` and `n`, shadowing the
    # arrays computed above.
    for rank_idx in range(1, rank):
        zeta_k = state['zeta'][rank_idx - 1].numpy()
        pd.DataFrame(
            zeta_k,
            columns=[f'feature_{i}' for i in range(1, zeta_k.shape[1] + 1)],
            # Row labels must match the number of rows (shape[0]); the old code
            # sized them from shape[1], which only works for square slices.
            index=[f'rank_{i}' for i in range(1, zeta_k.shape[0] + 1)],
        ).to_csv(os.path.join(save_dir, f"zeta_rank_{rank_idx}.csv"))

    pd.Series(
        state['sigma'].numpy(),
        index=reduced_labels,
    ).to_csv(os.path.join(save_dir, "sigma.csv"))

    os.makedirs(os.path.join(save_dir, 'Delta'), exist_ok=True)

    # The former np.transpose(..., (0, 1, 2)) was an identity permutation.
    Delta = state['Delta'].numpy()
    for slice_idx in range(Delta.shape[0]):
        pd.DataFrame(
            Delta[slice_idx],
            columns=reduced_labels,
            index=reduced_labels,
        ).to_csv(os.path.join(save_dir, f"Delta/Delta_{slice_idx}.csv"))
        
def ckpt_to_hdf5_save(save_dir, ckpt):
    """Export the parameters stored in a checkpoint to a single HDF5 file.

    Datasets written: ``t``, ``r``, ``e``, ``n``, ``c`` (each passed through
    ``logit_to_distribution``), ``T`` (the four ``_T0`` components stacked and
    reshaped to ``(2, 2, *component_shape)``), ``theta`` (from ``lamb``),
    ``Sigma_matrix``, ``Gamma`` (from ``Xi``), ``zeta``, ``sigma`` and ``Delta``.

    Parameters
    ----------
    save_dir : str
        Path of the HDF5 *file* to create (the name is kept for backward
        compatibility). An existing file is overwritten; missing parent
        directories are created.
    ckpt : dict
        Loaded checkpoint; only ``ckpt['state_dict']`` is read.

    Returns
    -------
    None
    """
    state = ckpt['state_dict']

    [[_cl, _cg], [_tl, _tg]] = state['_T0']
    cl = logit_to_distribution(_cl)
    cg = logit_to_distribution(_cg)
    tl = logit_to_distribution(_tl)
    tg = logit_to_distribution(_tg)

    # Values are converted to nested Python lists, exactly as before, so the
    # dtype h5py infers for each dataset is unchanged.
    datasets = {
        't': logit_to_distribution(state['_t']).tolist(),
        'r': logit_to_distribution(state['_r']).tolist(),
        'e': logit_to_distribution(state['_e']).tolist(),
        'n': logit_to_distribution(state['_n']).tolist(),
        'c': logit_to_distribution(state['_c']).tolist(),
        'T': torch.stack([cl, cg, tl, tg]).reshape(2, 2, *cl.shape).tolist(),
        'theta': logit_to_distribution(state['lamb']).tolist(),
        'Sigma_matrix': state['Sigma_mat'].tolist(),
        'Gamma': state['Xi'].tolist(),
        'zeta': state['zeta'].tolist(),
        'sigma': state['sigma'].tolist(),
        # The former .permute(0, 1, 2) was an identity permutation.
        'Delta': state['Delta'].tolist(),
    }

    # Create the parent directory so a fresh output tree does not make
    # h5py.File fail (the CSV exporter already creates its output directory).
    parent_dir = os.path.dirname(save_dir)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with h5py.File(save_dir, 'w') as hf:
        for name, values in datasets.items():
            hf[name] = values
    

# PCAWG

In [9]:
# Export the rank-18 checkpoint to a single HDF5 file under result/pcawg/.
# NOTE: despite its name, `save_dir` is the output *file* path here
# (ckpt_to_hdf5_save passes it straight to h5py.File).
rank=18
save_dir = f'result/pcawg/rank_{rank}_random.hdf5'

# map_location remaps every tensor in the checkpoint onto the CPU.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
ckpt = torch.load(f"checkpoints/ts/random_init/rank_{rank}_3.ckpt", map_location=torch.device('cpu'))
ckpt_to_hdf5_save(save_dir, ckpt)

In [4]:
import os

In [7]:
# List the files present in the checkpoint directory used by the cell above.
os.listdir("checkpoints/ts/random_init")

['rank_17_2.ckpt', 'rank_18_3.ckpt']

In [5]:
# Inspect the `_t` parameter of the currently loaded checkpoint.
# NOTE(review): the recorded output has 17 columns, whereas the PCAWG cell sets
# rank=18 -- presumably `ckpt` was a different checkpoint when this ran; confirm.
logit_to_distribution(ckpt['state_dict']['_t'])

tensor([[0.4994, 0.4498, 0.5008, 0.5323, 0.5491, 0.5191, 0.5185, 0.5343, 0.4810,
         0.5256, 0.5157, 0.3381, 0.5239, 0.5119, 0.3550, 0.4070, 0.7063],
        [0.5006, 0.5502, 0.4992, 0.4677, 0.4509, 0.4809, 0.4815, 0.4657, 0.5190,
         0.4744, 0.4843, 0.6619, 0.4761, 0.4881, 0.6450, 0.5930, 0.2937]])

# Liver Sanger

In [5]:
# Export the rank-5 checkpoint to a single HDF5 file under result/liver_sanger/.
# NOTE: despite its name, `save_dir` is the output *file* path here.
rank=5
save_dir = f'result/liver_sanger/rank_{rank}.hdf5'

# map_location remaps every tensor in the checkpoint onto the CPU.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
ckpt = torch.load(f"checkpoints/liver_sanger/rank_{rank}.ckpt", map_location=torch.device('cpu'))
ckpt_to_hdf5_save(save_dir, ckpt)

In [70]:
# Recompute the `_t` and `_r` tables from the currently loaded checkpoint as
# nested Python lists so they can be inspected in the following cells.
t = logit_to_distribution(ckpt['state_dict']['_t']).tolist()
r = logit_to_distribution(ckpt['state_dict']['_r']).tolist()

In [71]:
# Display `t`.
# NOTE(review): the recorded output has 20 entries per row, which does not match
# rank=5 from the Liver Sanger cell -- presumably produced with another
# checkpoint loaded (execution counts are out of order); confirm.
t

[[0.5403376221656799,
  0.51839679479599,
  0.5131873488426208,
  0.5176110863685608,
  0.5426681041717529,
  0.3091532289981842,
  0.355198472738266,
  0.415740042924881,
  0.5072880387306213,
  0.2266566902399063,
  0.41585785150527954,
  0.7611361742019653,
  0.5728378891944885,
  0.5064380764961243,
  0.6342465281486511,
  0.6903036832809448,
  0.5654654502868652,
  0.41797691583633423,
  0.48642638325691223,
  0.5338994264602661],
 [0.45966237783432007,
  0.4816032350063324,
  0.48681262135505676,
  0.4823889136314392,
  0.45733192563056946,
  0.6908467411994934,
  0.6448014974594116,
  0.5842599868774414,
  0.4927119314670563,
  0.7733433246612549,
  0.5841421484947205,
  0.23886382579803467,
  0.42716214060783386,
  0.49356192350387573,
  0.3657534718513489,
  0.3096962869167328,
  0.43453454971313477,
  0.5820230841636658,
  0.5135735869407654,
  0.4661006033420563]]

In [72]:
# Display `r` (same caveat as for `t`: the recorded output has 20 entries per row).
r

[[0.31347179412841797,
  0.4840059280395508,
  0.414541631937027,
  0.5943647623062134,
  0.4066096246242523,
  0.7962986826896667,
  0.5406166315078735,
  0.7330820560455322,
  0.4738008677959442,
  0.826583981513977,
  0.5489433407783508,
  0.4846862554550171,
  0.3440607190132141,
  0.3364056348800659,
  0.3707752525806427,
  0.547705352306366,
  0.5755954384803772,
  0.563076376914978,
  0.5055904984474182,
  0.4124873876571655],
 [0.686528205871582,
  0.5159940719604492,
  0.5854583978652954,
  0.40563520789146423,
  0.5933904051780701,
  0.20370131731033325,
  0.45938339829444885,
  0.2669179141521454,
  0.5261991620063782,
  0.17341600358486176,
  0.4510566294193268,
  0.5153137445449829,
  0.6559392809867859,
  0.6635943651199341,
  0.6292247772216797,
  0.45229464769363403,
  0.4244045615196228,
  0.43692365288734436,
  0.4944094717502594,
  0.5875126123428345]]

# Alzheimer

In [23]:
# NOTE(review): this value is overwritten by the next cell before it is used;
# it looks like a leftover from experimentation.
save_dir = 'data.h5'

In [48]:
# Export the rank-2 Alzheimer checkpoint to an HDF5 file.
# NOTE(review): the recorded run failed with IsADirectoryError because a
# directory named 'rrank_2.h5' already existed at this path.
rank = 2
save_dir = f'rrank_{rank}.h5'

# map_location remaps every tensor in the checkpoint onto the CPU.
ckpt = torch.load(f"checkpoints/alz/rank_{rank}.ckpt", map_location=torch.device('cpu'))

# ckpt_to_hdf5_save has no return value, so `t = ckpt_to_hdf5_save(...)` bound
# `t` to None. The following cells display and store `t`, so compute it
# explicitly from the checkpoint instead.
t = logit_to_distribution(ckpt['state_dict']['_t']).tolist()
ckpt_to_hdf5_save(save_dir, ckpt)

IsADirectoryError: [Errno 21] Unable to create file (unable to open file: name = 'rrank_2.h5', errno = 21, error message = 'Is a directory', flags = 13, o_flags = 602)

In [39]:
# Display `t`.
# NOTE(review): this cell's execution count (39) is lower than that of the cell
# above (48), so the recorded output predates that cell's last run.
t

[[0.5392181341335761, 0.4732204995946458],
 [0.46078186586642395, 0.5267795004053542]]

In [45]:
# Scratch check: write `t` as dataset 't' into a throwaway HDF5 file.
# The `with` block closes the file when it exits.
with h5py.File('asdfsa.h5', 'w') as f:
    f['t'] = t

In [44]:
# NOTE(review): `f` is the handle bound by the `with h5py.File(...) as f` cell,
# which already closes the file on exit -- this looks like leftover debugging
# and can probably be removed.
f.close()

In [20]:
# Export the rank-2 Alzheimer checkpoint as a directory of CSV files.
# Here `save_dir` really is a directory; ckpt_to_csv_save creates it if missing.
rank=2
save_dir = f'result/alz/rank_{rank}'

# map_location remaps every tensor in the checkpoint onto the CPU.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
ckpt = torch.load(f"checkpoints/alz/rank_{rank}.ckpt", map_location=torch.device('cpu'))
ckpt_to_csv_save(save_dir, ckpt)
