In [1]:
import sys

sys.path.append("../src/")
import pandas as pd
from src.networks.pathology import CoxResNet
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset
from src.downsample_hfaccl.hugg_downsample import setup_data
from datetime import datetime
from accelerate import Accelerator
import os

import matplotlib.pyplot as plt
import torch

from src.nanorcc.parse import get_rcc_data
from src.nanorcc.preprocess import (
    CodeClassGeneSelector,
    FunctionGeneSelector,
    Normalize,
)
from src.nanorcc.quality_control import QualityControl

from src.datasets.rcc_dataset import RCCDataset, rcc_to_csv


2023-04-06 18:01:09.038079: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
config: dict = {
    "lr": 1e-2,
    "num_epochs": 500,
    "seed": 42,
    "batch_size": 3,
    "target_level": 3,
    "num_workers": 10,
    "base_dir": "/data2/projects/DigiStrudMed_sklein/",
    "overfit": 1.0,
    "experiment_name": "bla",
    "data_split": [0.75, 0.20, 0.05],
    "annos_of_interest": [
        "Tissue",
        "Tumor_vital",
        "Angioinvasion",
        "Tumor_necrosis",
        "Tumor_regression",
    ],
    "grad_accum_steps": 12,
    "date": datetime.now().strftime("%Y-%m-%d"),
}
level: int = int(config["target_level"])

config["cache_path"] = (
    config["base_dir"]
    + f"downsampled_datasets/cached_DownsampleDataset_level_{level}.json"
)


In [3]:
normalized_df = rcc_to_csv(
    "/data2/projects/DigiStrudMed_sklein/DigiStrucMed_Braesen/NanoString_RCC/",
    "/data2/projects/DigiStrudMed_sklein/DigiStrucMed_Braesen/",
)
normalized_df.shape

(24, 770)

In [4]:
normalized_df

Unnamed: 0_level_0,A2M,ACVR1C,ADAM12,ADGRE1,ADM,ADORA2A,AKT1,ALDOA,ALDOC,ANGPT1,...,PUM1,SDHA,SF3A1,STK11IP,TBC1D10B,TBP,TFRC,TLK2,TMUB2,UBB
SampleID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
172,4355.544835,22.23879,61.900325,4.107802,2126.566542,75.498566,1642.69581,17638.759652,809.803565,69.832632,...,644.358303,3519.25303,919.722678,111.760541,273.806243,305.535471,1064.770579,263.607562,458.51568,6013.397017
165,18651.870674,26.268632,472.454667,23.222993,2249.584712,108.50087,2580.036486,25696.432448,454.180836,140.480074,...,1351.121365,2375.978709,1034.374964,108.50087,405.450621,662.807071,1902.381927,512.047967,801.383622,9493.635805
136,2204.98393,6.860665,6.183069,-0.592897,1158.097233,55.647618,550.293112,5761.010716,381.571567,41.41809,...,260.281781,570.621009,270.44573,24.478176,65.811567,128.150451,189.811738,82.073884,158.642296,2713.181356
162,15533.582084,18.800393,429.69557,111.833263,4397.547494,265.337499,3247.040996,24742.285709,1014.252106,327.359413,...,1718.200826,4214.582849,1420.49564,149.046411,511.874606,693.288703,1192.565108,594.053641,786.321574,11553.325778
129,51.008248,1.100897,5.590832,-0.626001,6.281591,0.928208,20.269464,762.317505,13.879942,0.928208,...,5.936212,9.217318,12.498424,0.582828,2.137036,1.964346,11.980354,5.072763,5.245452,57.22508
167,13889.688971,37.017234,1364.15364,61.146987,2934.781187,578.839866,4349.662149,33809.896736,657.809966,173.021296,...,1706.357407,2392.958556,2627.675242,142.310701,800.394869,896.913881,2673.741134,905.688336,809.169325,24522.135511
151,3203.758798,11.082103,40.978011,1.804063,857.960988,42.524351,674.461972,3454.265881,150.768152,59.534091,...,314.164747,292.515987,239.42498,16.752017,73.966598,99.738931,199.220139,89.429998,116.748672,1936.275428
188,21172.992598,20.073612,74.225682,76.092995,5254.151289,288.96665,3937.69579,31510.436057,3063.793417,421.545856,...,1926.599944,4047.867243,2027.434833,154.520131,645.623388,1248.765411,1463.506378,910.7818,1370.14074,27697.383394
163,636.164863,3.44236,93.111632,-1.93108,429.287442,11.502519,167.668105,1548.306222,22.249398,5.793239,...,61.206835,117.62795,91.096592,6.800759,31.317077,48.780756,66.916114,23.592758,41.056436,682.846619
119,3646.718895,10.008919,184.041546,6.7407,930.829441,50.044594,916.122458,7053.836532,322.123772,59.849249,...,422.621486,1774.846831,307.416789,29.618229,86.81205,203.650856,327.843154,137.469435,215.089621,4051.977971


In [5]:
normalized_df.loc[172]

A2M       4355.544835
ACVR1C      22.238790
ADAM12      61.900325
ADGRE1       4.107802
ADM       2126.566542
             ...     
TBP        305.535471
TFRC      1064.770579
TLK2       263.607562
TMUB2      458.515680
UBB       6013.397017
Name: 172, Length: 770, dtype: float64

In [6]:
dataset = RCCDataset(
    "/data2/projects/DigiStrudMed_sklein/DigiStrucMed_Braesen/NanoString_RCC/",
    config["base_dir"] + "survival_status.csv",
    sparse=177,
)


In [7]:
iter = iter(dataset)

batch = next(iter)
batch[0][-4:]


tensor([1064.7706,  263.6076,  458.5157, 6013.3970])

In [8]:
batch[2]

tensor(0.)

In [9]:
survival_data = config["base_dir"] + "survival_status.csv"
tabular: pd.DataFrame = pd.read_csv(survival_data)
# tabular.index = tabular["case"]  #
tabular.columns

Index(['case', 'surv_days', 'death', 'uncensored'], dtype='object')

In [10]:
tabular


Unnamed: 0,case,surv_days,death,uncensored
0,1,218.0,1.0,1.0
1,2,2466.0,1.0,1.0
2,3,550.0,1.0,1.0
3,4,331.0,1.0,1.0
4,5,27.0,1.0,1.0
...,...,...,...,...
136,184,2490.0,0.0,0.0
137,186,2103.0,0.0,0.0
138,187,2023.0,0.0,0.0
139,188,961.0,0.0,0.0


In [11]:
normalized_df

Unnamed: 0_level_0,A2M,ACVR1C,ADAM12,ADGRE1,ADM,ADORA2A,AKT1,ALDOA,ALDOC,ANGPT1,...,PUM1,SDHA,SF3A1,STK11IP,TBC1D10B,TBP,TFRC,TLK2,TMUB2,UBB
SampleID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
172,4355.544835,22.23879,61.900325,4.107802,2126.566542,75.498566,1642.69581,17638.759652,809.803565,69.832632,...,644.358303,3519.25303,919.722678,111.760541,273.806243,305.535471,1064.770579,263.607562,458.51568,6013.397017
165,18651.870674,26.268632,472.454667,23.222993,2249.584712,108.50087,2580.036486,25696.432448,454.180836,140.480074,...,1351.121365,2375.978709,1034.374964,108.50087,405.450621,662.807071,1902.381927,512.047967,801.383622,9493.635805
136,2204.98393,6.860665,6.183069,-0.592897,1158.097233,55.647618,550.293112,5761.010716,381.571567,41.41809,...,260.281781,570.621009,270.44573,24.478176,65.811567,128.150451,189.811738,82.073884,158.642296,2713.181356
162,15533.582084,18.800393,429.69557,111.833263,4397.547494,265.337499,3247.040996,24742.285709,1014.252106,327.359413,...,1718.200826,4214.582849,1420.49564,149.046411,511.874606,693.288703,1192.565108,594.053641,786.321574,11553.325778
129,51.008248,1.100897,5.590832,-0.626001,6.281591,0.928208,20.269464,762.317505,13.879942,0.928208,...,5.936212,9.217318,12.498424,0.582828,2.137036,1.964346,11.980354,5.072763,5.245452,57.22508
167,13889.688971,37.017234,1364.15364,61.146987,2934.781187,578.839866,4349.662149,33809.896736,657.809966,173.021296,...,1706.357407,2392.958556,2627.675242,142.310701,800.394869,896.913881,2673.741134,905.688336,809.169325,24522.135511
151,3203.758798,11.082103,40.978011,1.804063,857.960988,42.524351,674.461972,3454.265881,150.768152,59.534091,...,314.164747,292.515987,239.42498,16.752017,73.966598,99.738931,199.220139,89.429998,116.748672,1936.275428
188,21172.992598,20.073612,74.225682,76.092995,5254.151289,288.96665,3937.69579,31510.436057,3063.793417,421.545856,...,1926.599944,4047.867243,2027.434833,154.520131,645.623388,1248.765411,1463.506378,910.7818,1370.14074,27697.383394
163,636.164863,3.44236,93.111632,-1.93108,429.287442,11.502519,167.668105,1548.306222,22.249398,5.793239,...,61.206835,117.62795,91.096592,6.800759,31.317077,48.780756,66.916114,23.592758,41.056436,682.846619
119,3646.718895,10.008919,184.041546,6.7407,930.829441,50.044594,916.122458,7053.836532,322.123772,59.849249,...,422.621486,1774.846831,307.416789,29.618229,86.81205,203.650856,327.843154,137.469435,215.089621,4051.977971


In [12]:
normalized_df.reset_index()

Unnamed: 0,SampleID,A2M,ACVR1C,ADAM12,ADGRE1,ADM,ADORA2A,AKT1,ALDOA,ALDOC,...,PUM1,SDHA,SF3A1,STK11IP,TBC1D10B,TBP,TFRC,TLK2,TMUB2,UBB
0,172,4355.544835,22.23879,61.900325,4.107802,2126.566542,75.498566,1642.69581,17638.759652,809.803565,...,644.358303,3519.25303,919.722678,111.760541,273.806243,305.535471,1064.770579,263.607562,458.51568,6013.397017
1,165,18651.870674,26.268632,472.454667,23.222993,2249.584712,108.50087,2580.036486,25696.432448,454.180836,...,1351.121365,2375.978709,1034.374964,108.50087,405.450621,662.807071,1902.381927,512.047967,801.383622,9493.635805
2,136,2204.98393,6.860665,6.183069,-0.592897,1158.097233,55.647618,550.293112,5761.010716,381.571567,...,260.281781,570.621009,270.44573,24.478176,65.811567,128.150451,189.811738,82.073884,158.642296,2713.181356
3,162,15533.582084,18.800393,429.69557,111.833263,4397.547494,265.337499,3247.040996,24742.285709,1014.252106,...,1718.200826,4214.582849,1420.49564,149.046411,511.874606,693.288703,1192.565108,594.053641,786.321574,11553.325778
4,129,51.008248,1.100897,5.590832,-0.626001,6.281591,0.928208,20.269464,762.317505,13.879942,...,5.936212,9.217318,12.498424,0.582828,2.137036,1.964346,11.980354,5.072763,5.245452,57.22508
5,167,13889.688971,37.017234,1364.15364,61.146987,2934.781187,578.839866,4349.662149,33809.896736,657.809966,...,1706.357407,2392.958556,2627.675242,142.310701,800.394869,896.913881,2673.741134,905.688336,809.169325,24522.135511
6,151,3203.758798,11.082103,40.978011,1.804063,857.960988,42.524351,674.461972,3454.265881,150.768152,...,314.164747,292.515987,239.42498,16.752017,73.966598,99.738931,199.220139,89.429998,116.748672,1936.275428
7,188,21172.992598,20.073612,74.225682,76.092995,5254.151289,288.96665,3937.69579,31510.436057,3063.793417,...,1926.599944,4047.867243,2027.434833,154.520131,645.623388,1248.765411,1463.506378,910.7818,1370.14074,27697.383394
8,163,636.164863,3.44236,93.111632,-1.93108,429.287442,11.502519,167.668105,1548.306222,22.249398,...,61.206835,117.62795,91.096592,6.800759,31.317077,48.780756,66.916114,23.592758,41.056436,682.846619
9,119,3646.718895,10.008919,184.041546,6.7407,930.829441,50.044594,916.122458,7053.836532,322.123772,...,422.621486,1774.846831,307.416789,29.618229,86.81205,203.650856,327.843154,137.469435,215.089621,4051.977971


In [21]:
pd.merge(normalized_df, tabular, how="outer", left_index=True, right_on="case").loc[122]


A2M           18651.870674
ACVR1C           26.268632
ADAM12          472.454667
ADGRE1           23.222993
ADM            2249.584712
                  ...     
UBB            9493.635805
case            165.000000
surv_days       644.000000
death             1.000000
uncensored        1.000000
Name: 122, Length: 774, dtype: float64

In [31]:
data = pd.merge(normalized_df.reset_index(), tabular, how="outer", left_on='SampleID', right_on="case")
data.index = data['case']
data = data.fillna(0)

In [33]:
data

Unnamed: 0_level_0,SampleID,A2M,ACVR1C,ADAM12,ADGRE1,ADM,ADORA2A,AKT1,ALDOA,ALDOC,...,TBC1D10B,TBP,TFRC,TLK2,TMUB2,UBB,case,surv_days,death,uncensored
case,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
172,172.0,4355.544835,22.238790,61.900325,4.107802,2126.566542,75.498566,1642.695810,17638.759652,809.803565,...,273.806243,305.535471,1064.770579,263.607562,458.515680,6013.397017,172,5084.0,0.0,0.0
165,165.0,18651.870674,26.268632,472.454667,23.222993,2249.584712,108.500870,2580.036486,25696.432448,454.180836,...,405.450621,662.807071,1902.381927,512.047967,801.383622,9493.635805,165,644.0,1.0,1.0
136,136.0,2204.983930,6.860665,6.183069,-0.592897,1158.097233,55.647618,550.293112,5761.010716,381.571567,...,65.811567,128.150451,189.811738,82.073884,158.642296,2713.181356,136,3751.0,0.0,0.0
162,162.0,15533.582084,18.800393,429.695570,111.833263,4397.547494,265.337499,3247.040996,24742.285709,1014.252106,...,511.874606,693.288703,1192.565108,594.053641,786.321574,11553.325778,162,1828.0,0.0,0.0
129,129.0,51.008248,1.100897,5.590832,-0.626001,6.281591,0.928208,20.269464,762.317505,13.879942,...,2.137036,1.964346,11.980354,5.072763,5.245452,57.225080,129,414.0,1.0,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
175,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,175,744.0,0.0,0.0
176,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,176,265.0,1.0,1.0
179,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,179,1052.0,0.0,0.0
181,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,181,1893.0,1.0,1.0


In [20]:
data = pd.merge(normalized_df, tabular, left_index=True, right_index=True)


In [24]:
data[data['uncensored'] == 1]

Unnamed: 0,A2M,ACVR1C,ADAM12,ADGRE1,ADM,ADORA2A,AKT1,ALDOA,ALDOC,ANGPT1,...,TBC1D10B,TBP,TFRC,TLK2,TMUB2,UBB,case,surv_days,death,uncensored
165,18651.870674,26.268632,472.454667,23.222993,2249.584712,108.50087,2580.036486,25696.432448,454.180836,140.480074,...,405.450621,662.807071,1902.381927,512.047967,801.383622,9493.635805,165,644.0,1.0,1.0
129,51.008248,1.100897,5.590832,-0.626001,6.281591,0.928208,20.269464,762.317505,13.879942,0.928208,...,2.137036,1.964346,11.980354,5.072763,5.245452,57.22508,129,414.0,1.0,1.0
151,3203.758798,11.082103,40.978011,1.804063,857.960988,42.524351,674.461972,3454.265881,150.768152,59.534091,...,73.966598,99.738931,199.220139,89.429998,116.748672,1936.275428,151,7599.0,1.0,1.0
163,636.164863,3.44236,93.111632,-1.93108,429.287442,11.502519,167.668105,1548.306222,22.249398,5.793239,...,31.317077,48.780756,66.916114,23.592758,41.056436,682.846619,163,545.0,1.0,1.0
119,3646.718895,10.008919,184.041546,6.7407,930.829441,50.044594,916.122458,7053.836532,322.123772,59.849249,...,86.81205,203.650856,327.843154,137.469435,215.089621,4051.977971,119,2677.0,1.0,1.0
159,8787.383122,42.471992,436.957611,79.904934,2513.04616,241.15453,5715.002425,62512.293186,197.962674,186.444845,...,894.791285,1654.967953,2251.015567,1361.263331,1378.540074,45814.321621,159,490.0,1.0,1.0
166,312.765722,2.072038,28.056513,4.536083,223.388087,5.656103,155.066837,1535.156084,95.033739,9.240169,...,14.840271,25.816472,68.153247,22.456411,31.416575,399.007299,166,1238.0,1.0,1.0
113,671.027537,1.932206,22.752711,4.771366,286.794575,12.657921,303.829534,3576.11896,106.034733,13.91977,...,41.680444,41.364981,117.706834,35.686662,57.453554,1078.289241,113,2473.0,1.0,1.0
177,12515.07874,43.234371,1659.578167,20.628164,5509.415209,343.896923,4148.521551,48777.695294,782.457338,174.350371,...,990.434442,1408.64927,2936.828859,1132.853545,1297.878856,28135.967736,177,346.0,1.0,1.0
160,3154.995513,7.408577,309.043505,7.408577,576.810652,76.202508,945.122775,7931.411061,306.926769,64.560458,...,149.229912,182.039325,324.919028,150.28828,245.541415,4139.27791,160,1502.0,1.0,1.0


In [14]:
tabular.loc[190]

KeyError: 190

In [32]:
list(tabular.index)

[1,
 2,
 3,
 4,
 5,
 6,
 8,
 9,
 10,
 11,
 12,
 13,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 25,
 26,
 27,
 28,
 30,
 31,
 32,
 33,
 34,
 37,
 38,
 39,
 40,
 42,
 43,
 44,
 45,
 46,
 47,
 48,
 49,
 50,
 52,
 53,
 54,
 55,
 56,
 57,
 58,
 59,
 62,
 64,
 68,
 70,
 71,
 72,
 73,
 74,
 75,
 76,
 78,
 79,
 80,
 81,
 82,
 83,
 84,
 85,
 86,
 87,
 88,
 89,
 90,
 92,
 93,
 95,
 96,
 107,
 109,
 111,
 112,
 113,
 114,
 115,
 116,
 117,
 118,
 119,
 120,
 121,
 122,
 123,
 124,
 126,
 127,
 128,
 129,
 130,
 132,
 135,
 136,
 138,
 139,
 140,
 142,
 143,
 144,
 145,
 146,
 147,
 148,
 149,
 151,
 156,
 158,
 159,
 160,
 161,
 162,
 163,
 164,
 165,
 166,
 167,
 169,
 170,
 171,
 172,
 173,
 174,
 175,
 176,
 177,
 179,
 181,
 184,
 186,
 187,
 188,
 189]

In [11]:
a,b,c,d = tabular.loc[172]

In [13]:
a

172.0

In [12]:
print(a, b, c, d)

172.0 5084.0 0.0 0.0


In [28]:
merged = pd.merge(normalized_df, tabular, left_index=True, right_index=True)


In [None]:
# Initialize accelerator
accelerator = Accelerator(
    mixed_precision="bf16",
    gradient_accumulation_steps=config["grad_accum_steps"],
    project_dir=config["base_dir"] + "huggingface/",
    step_scheduler_with_optimizer=True,  # loss fluctuates a lot, so we only step the scheduler after each epoch
)
train_dataloader, eval_dataloader, test_dataloader = setup_data(
    config=config, accelerator=accelerator
)


In [None]:
steps = []
lrs = []
model = CoxResNet(8)  # Your model instance
optimizer = torch.optim.AdamW(params=model.parameters(), lr=config["lr"])
scheduler = scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=config["lr"],
    total_steps=config["num_epochs"]
    * len(train_dataloader)
    // config["grad_accum_steps"],
    anneal_strategy="cos",
    final_div_factor=25,
)

for step in range(
    config["num_epochs"] * len(train_dataloader) // config["grad_accum_steps"]
):
    scheduler.step()
    lrs.append(scheduler.get_last_lr()[0])
    steps.append(step)

plt.figure()
plt.legend()
plt.plot(steps, lrs, label="OneCycle")
plt.show()


In [None]:
lrs[15]


In [19]:
sorted([0.85, 0.15, 0.0])

[0.0, 0.15, 0.85]