In [0]:
# %pip install ctgan

In [0]:
import ctgan
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

data = pd.read_csv("./gene_expression_39_40_50.csv")
classes = pd.read_csv("./meta_50_three_class(1).csv")

# Helper Functions

In [0]:
def plot_tsne(gendata, actual_path, name, epoch = None):
    n_patients, n_genes = gendata.shape
    print(gendata.shape)
    df = pd.read_csv(actual_path)
    df.reset_index(drop=True, inplace=True)
    df = df.values
    df = df[:, 1:]
    print(df.shape)
    # df = df[:n_patients, :n_genes]
    df = df.astype(np.float64)
    labels_train = (np.ones((df.shape[0], 1))).astype(np.float64)
    labels_train = pd.DataFrame(labels_train, columns = ['label'])
    train_map = {1:"Train_Data"}
    df = pd.concat([pd.DataFrame(df), labels_train], axis = 1)
    df['label'] = df['label'].map(train_map)

    labels_test = (np.zeros((gendata.shape[0], 1))).astype(np.float64)
    labels_test = pd.DataFrame(labels_test, columns = ['label'])
    gen_map = {0: "Gen_Data"}
    gendata = pd.concat([gendata, labels_test], axis=1)
    gendata.to_csv("./gendata.csv")
    gendata['label'] = gendata['label'].map(gen_map)

    df.reset_index(drop=True, inplace=True)
    # df = df.iloc[]
    dfeatures = pd.concat([df, gendata], ignore_index=True,axis=0)
    print(dfeatures)
    labels = dfeatures[['label']]
    dfeatures.drop(['label'], axis=1, inplace=True)
    df.drop(['label'], axis=1, inplace=True)
    gendata.drop(['label'], axis=1, inplace=True)
    
    X_embedded = TSNE(n_components=2, random_state=0, perplexity=100).fit_transform(dfeatures)
    X_embedded = pd.DataFrame(X_embedded, columns = ['dim1', 'dim2'])
    X_embedded = pd.DataFrame(np.hstack([np.array(X_embedded), np.array(labels)]))
    X_embedded.columns = ['dim1','dim2','label']

    sns_fig = sns.lmplot(x='dim1', y='dim2', data=X_embedded,fit_reg=False, hue='label'
                    , markers=["x","o"],
                    palette = dict(Gen_Data= (0.568, 0.508, 0.084) ,Train_Data= (0.325, 0.843, 0.078)))

    filename = f"./tsne_plot_{label_disease_dictionary[actual_path[-5]]}"
    if epoch is not None:
        filename = filename + f"_{epoch}"
    filename = filename + name + ".png"
    plt.savefig(filename)

In [0]:
def label_def(row):
    if row['tuberculosis']==0 and row['hiv']==0:
        return 0
    elif row['tuberculosis']==1 and row['hiv']==0:
        return 1
    elif row['tuberculosis']==2 and row['hiv']==0:
        return 2
    elif row['tuberculosis']==0 and row['hiv']==1:
        return 3
    elif row['tuberculosis']==1 and row['hiv']==1:
        return 4
    elif row['tuberculosis']==2 and row['hiv']==1:
        return 5

label_disease_dictionary = dict()
label_disease_dictionary["0"] = "tuberculosis_0_hiv_0"
label_disease_dictionary["1"] = "tuberculosis_1_hiv_0"
label_disease_dictionary["2"] = "tuberculosis_2_hiv_0"
label_disease_dictionary["3"] = "tuberculosis_0_hiv_1"
label_disease_dictionary["4"] = "tuberculosis_1_hiv_1"
label_disease_dictionary["5"] = "tuberculosis_2_hiv_1"

# Data Processing

In [0]:
pca = PCA(0.9)

In [0]:
data = data.T

In [0]:
data

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,534,535,536,537,538,539,540,541,542,543,544,545,546,547,548,549,550,551,552,553,554,555,556,557,558,559,560,561,562,563,564,565,566,567,568,569,570,571,572,573
6247215025_L.AVG_Signal,119.645000,60.05559,105.129700,114.789540,82.230020,76.839167,112.205400,483.69410,125.77210,343.0028,821.5002,1391.21895,407.84015,813.954933,99.10909,901.9719,74.503880,91.982860,483.609473,81.34829,85.01229,1506.150045,62.645220,120.864900,96.903335,176.82940,109.409800,2751.390480,395.324467,100.985200,364.28840,87.358300,168.136560,4013.01190,102.604600,300.5773,258.99670,110.54770,3462.0945,315.77180,...,121.24270,84.28473,130.028027,119.57780,451.2493,115.723950,61.73062,89.593010,5553.260,1162.8270,1332.88850,289.5434,151.61500,67.95764,123.794400,497.255633,5252.228,393.23020,93.83552,124.827500,91.271140,2665.07800,457.9658,2475.4770,108.483700,96.074870,103.532100,133.67630,125.07090,200.127520,87.892940,98.088170,144.456048,5355.013,551.362157,75.919075,94.924450,95.426970,1543.5390,88.143910
6142077055_K.AVG_Signal,119.563500,78.60551,117.106600,127.688175,94.462090,96.357827,116.783100,441.63060,157.73980,187.9323,1420.1440,3125.47050,770.05345,721.950867,130.78490,2025.1680,87.367150,96.867460,881.825967,99.59898,79.52830,2078.154650,105.081100,132.970100,107.670305,189.54470,143.092400,5664.350620,949.082067,267.741000,604.53030,144.548700,283.839233,12168.38150,129.604700,409.3191,249.91810,135.88270,4685.3600,229.71780,...,136.81370,105.62160,127.140603,124.40740,608.7090,159.409050,70.50216,87.608100,13759.780,1613.2290,1282.47155,290.4000,189.79940,85.24671,145.833300,585.153933,5966.647,357.49160,101.11270,144.394500,98.093560,4934.81350,726.1572,2233.0430,114.607400,90.263480,92.061210,189.95810,226.57380,200.174270,127.669700,112.409000,166.077450,11872.410,1042.443107,84.112555,85.889970,103.347400,1850.6630,101.974700
6142077060_J.AVG_Signal,123.213900,81.73999,98.452870,114.761735,91.284820,87.890133,119.201000,381.25750,158.30070,201.4227,736.7297,1922.11200,648.16985,596.411133,105.96840,2124.5980,110.802800,123.948400,1124.015700,79.29423,76.83108,1984.488200,81.495510,118.683100,115.940800,278.37390,93.205080,5783.591340,653.534300,186.827200,521.25320,109.870900,462.407433,9859.63200,111.365800,341.7842,924.24920,131.94350,4214.2555,229.88910,...,146.22730,110.83210,110.870677,99.73404,539.8279,121.706295,78.89526,85.985260,2181.246,1458.2520,2368.64920,234.0573,150.14010,92.12841,127.132500,516.829933,5920.334,333.57100,111.88550,127.109450,78.873330,2943.64350,720.5342,2126.0330,119.943100,98.709310,108.762200,176.66920,199.94870,206.967910,105.469000,108.379200,215.601587,3635.568,905.905538,108.985700,104.408800,95.085010,2091.1970,89.287840
6142077055_I.AVG_Signal,129.417000,59.07066,111.201100,118.669310,103.244200,86.048347,116.871700,573.88670,132.90580,230.1286,1073.0330,1921.91410,534.38655,722.428300,113.17880,1811.6460,88.948520,85.053340,1572.604533,85.11079,72.37727,2602.577795,148.429100,144.800900,106.308075,290.24400,147.874600,4750.135160,650.359400,169.824000,415.28320,101.325400,1052.893500,5758.27700,120.805500,332.9302,274.15650,119.52280,5175.8665,262.58590,...,143.26040,88.24596,135.901967,96.61730,457.9533,131.134650,92.72658,90.405270,2477.022,1529.3450,2266.58840,260.8342,162.56940,79.96722,178.492900,643.190000,9134.733,405.16130,96.50447,133.606150,96.858920,2292.65445,531.2582,2458.5720,141.202700,96.426910,113.459400,173.82240,154.47620,226.564150,102.399200,144.186400,202.411227,3846.824,844.773790,84.867165,122.212200,109.834900,2425.5130,80.197620
6142077060_K.AVG_Signal,105.425400,76.26342,127.926000,109.444825,92.184070,89.378613,105.609000,507.39920,146.20120,201.8669,727.2063,2040.71945,588.65280,610.464600,101.62680,1152.3130,90.135470,95.069260,987.288333,89.43764,80.65396,2154.715150,102.888500,131.948300,96.557350,234.85730,98.758270,3489.721220,518.925867,90.492530,277.44020,108.872800,508.667900,3248.65500,110.402000,296.4421,178.06170,146.37140,5332.9525,277.87490,...,117.23260,104.40630,104.468700,159.51240,541.9583,139.178150,78.40203,110.282200,2451.522,1591.6590,2449.08485,265.2744,160.34570,97.42243,173.925400,620.247167,8641.364,326.88500,95.64765,135.643000,96.122340,1997.75205,461.8423,1872.5820,124.725800,96.374530,106.004900,184.42570,189.10640,224.066515,90.843880,124.378800,198.897760,2194.518,485.782218,92.057490,111.776100,108.756500,1848.8670,79.913120
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5685521033_H.AVG_Signal,-2.189815,-46.03613,5.749578,-12.381135,-35.507860,-19.056880,16.724690,91.31330,196.48990,103.4944,580.3944,2988.64500,408.34990,243.923583,-31.70149,612.9030,-6.780165,-23.543180,1314.362713,-40.17883,-35.57662,1989.938984,-15.461580,12.767190,-11.483525,99.92490,-15.143070,2075.880986,410.396880,14.539320,84.07471,34.214870,653.081333,3262.37795,18.328810,182.5165,60.58614,49.99961,4443.8295,33.20175,...,70.38071,-16.02555,-10.981199,31.53091,510.5679,39.816028,-28.99254,-27.877710,5807.106,1262.3340,1878.97200,305.1814,71.75225,-41.35888,45.078830,996.229687,10601.830,213.91500,-35.90978,33.817135,-1.273051,1463.31810,331.7239,1076.6720,43.981830,-12.465280,-8.110441,66.41614,110.42260,110.155040,12.599500,6.945223,133.361137,1148.166,447.759818,-17.679555,32.440620,-7.966771,536.7505,-18.575380
5685521033_I.AVG_Signal,25.470250,-42.13654,11.191050,-11.422420,-2.354719,-20.342854,7.849306,128.46680,40.79843,102.2480,371.9373,1300.58795,334.49135,163.513623,-36.61034,622.0648,-12.945520,-8.340496,659.970398,-36.96067,-23.04228,990.979862,5.615935,3.647199,23.322283,72.27074,-44.022660,1470.075340,218.503877,1.648164,59.06745,6.764814,286.405233,1578.32235,-4.809433,177.4125,66.92413,38.17730,4460.5135,74.45583,...,51.75272,-27.47854,-10.322463,-20.05265,227.7927,47.411535,-17.86875,-27.597810,2183.061,850.8649,1254.33166,190.5734,33.63958,-28.16190,28.170690,663.856963,9544.950,173.36150,-28.70290,-1.535435,-24.503820,1024.14025,214.9556,989.5558,5.839520,2.862215,-13.436850,63.25101,75.10828,63.852785,1.056084,35.440900,56.424470,1123.725,260.548230,-8.033682,-17.464670,-1.763262,716.7267,9.340691
5685521033_J.AVG_Signal,5.203332,-37.44641,-7.461519,-2.290285,-32.947030,-17.947770,-3.632384,151.23410,67.00423,140.1396,454.0324,2020.61855,281.87040,202.314533,-20.14680,285.4071,-35.914700,40.765270,634.781051,-34.78794,-34.44548,977.607143,-9.750414,-2.055379,5.803905,93.59048,-0.261585,1396.325664,257.997140,26.054540,67.84500,12.149200,533.598633,2820.03845,15.263220,253.5119,110.30980,43.71063,4251.8785,85.12035,...,60.84709,-28.83603,-4.920416,-19.62059,168.6748,21.935474,-41.69322,-29.897310,4859.993,907.1992,1787.07092,155.8786,36.87611,-41.74828,39.664970,855.127804,12616.890,175.71950,-39.27529,0.639645,4.070309,1131.92520,215.9942,1199.4110,30.978370,-29.115160,-11.689390,69.64512,60.09529,58.338865,51.640040,14.996640,76.147788,1595.139,233.929300,-34.181275,10.931870,-4.655565,580.3944,-8.010675
5685521033_K.AVG_Signal,-11.847170,-48.77342,41.844610,29.301540,-50.350070,-18.222930,6.361732,7.26864,456.90490,212.5967,240.9312,1553.50290,367.43405,329.663760,-16.15349,1198.0950,-42.081510,-29.159050,1012.875630,-18.89009,-24.05211,2656.682369,-24.556000,21.960930,16.408357,115.41050,28.417200,1574.131868,449.429540,-1.567851,105.34070,21.978490,136.092857,3152.96870,32.736020,194.3142,255.83760,83.12910,3297.5555,-13.12783,...,229.70940,-18.85476,-20.310017,-13.61805,583.0040,58.436177,-39.68929,-7.578042,1861.644,821.5377,1981.77985,544.7303,58.56222,-23.44128,6.188788,777.306373,6261.345,143.57140,-54.88902,55.307040,30.162160,1455.17080,342.2286,773.6067,3.814982,16.976040,14.148190,58.41484,89.83482,78.295545,8.890541,38.712060,87.952589,1108.522,351.473153,-23.262255,24.555860,-2.064317,509.8944,-22.590360


In [0]:
names = data.index.values
data.reset_index(drop=True, inplace=True)
data = pd.concat([data, pd.DataFrame(names, columns=["host_name"])], axis=1)

In [0]:
data.shape

(1028, 575)

In [0]:
classes['label'] = classes.apply(lambda x: label_def(x), axis=1)
data = pd.merge(data, classes, on='host_name')
data

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,538,539,540,541,542,543,544,545,546,547,548,549,550,551,552,553,554,555,556,557,558,559,560,561,562,563,564,565,566,567,568,569,570,571,572,573,host_name,tuberculosis,hiv,label
0,22.977790,-46.11452,3.379508,-4.298570,-30.656710,-21.974197,6.757988,50.23626,220.99770,203.1971,708.7916,1661.44655,470.32645,238.714727,3.026244,1920.9470,-23.456190,9.644024,840.716017,-46.44247,-47.323830,2301.860000,-33.646930,27.443540,18.534980,46.75706,2.794634,2873.504760,631.742823,21.021710,361.70300,73.888920,129.573850,7617.18050,-6.016032,294.3749,324.19160,50.519650,2569.3675,3.682932,...,427.2864,44.471532,-33.76612,-7.858823,1691.314,837.6642,1254.680995,255.5322,49.75186,-26.93139,-11.995730,545.262778,3484.594,199.45760,-55.05777,35.053051,3.399793,2509.53880,536.0202,454.5185,-5.451852,-16.099670,-16.331790,42.69594,115.29500,74.278450,-9.138244,11.752660,59.211828,1428.479,800.344655,-18.816740,-18.305200,8.300746,659.5989,4.898625,5483347021_A.AVG_Signal,1,0,1
1,24.390000,-49.27365,4.602244,-17.406060,-26.892460,-16.237700,-2.187479,45.99354,168.13400,205.0580,405.6895,2743.80150,395.42655,196.898758,5.178334,1014.8630,-19.024100,1.942063,1012.879304,-44.20679,-36.457795,1678.046375,-28.453730,33.461790,8.595460,109.82810,69.420520,3046.885720,499.203333,25.955820,184.66710,39.465270,219.375267,8029.15500,26.702920,196.4899,419.51500,51.560810,3544.6180,23.892720,...,527.6553,35.925996,-26.31075,-3.956610,3886.890,1194.9450,1717.704100,205.8496,26.72584,-30.33980,6.085990,828.556110,4562.318,144.43700,-48.45786,39.471365,-1.511332,1786.24300,501.2137,896.3716,76.160110,-5.510161,-19.911600,80.46590,96.16717,110.483339,9.100771,-6.079873,84.302540,1199.411,357.877343,-26.722555,-9.039674,5.994009,651.8876,-13.996800,5483347021_B.AVG_Signal,1,0,1
2,25.728400,-41.38828,23.113950,0.701020,-13.579480,-23.266679,8.795504,103.67210,98.01786,209.2793,742.9699,2569.88400,480.45970,210.410058,-10.745800,434.5760,-32.961940,-2.481033,853.980037,-31.64151,-42.335595,1410.080053,-39.066760,7.241544,6.274940,50.48791,9.098358,2229.952360,454.817027,12.847770,176.55140,26.906610,251.809500,5854.09705,4.223186,266.9550,99.26867,37.875670,4210.8510,22.267560,...,558.2650,58.951976,-31.40832,-20.634650,7450.395,1110.9310,1372.880350,206.0475,39.94660,-31.13787,22.867580,780.782468,10513.320,198.23080,-57.15127,24.871980,-15.068220,2032.32150,405.7916,670.4788,5.320835,-5.530854,-3.949874,58.19359,90.01529,76.602330,10.856310,21.095920,91.501303,1535.089,566.085672,-24.959640,-16.222590,6.536697,474.1375,9.648779,5483347021_C.AVG_Signal,1,1,4
3,40.313390,-45.15408,-7.856186,-4.906975,-19.610780,-11.327387,1.588020,36.84473,123.86790,196.4014,517.2504,5174.66450,443.81340,170.045640,-2.900939,1044.2460,-30.381400,-5.812512,1158.235857,-40.83022,-48.552395,1366.329210,-18.841790,3.316790,15.477540,81.11749,-10.552080,2368.191980,477.770000,9.402896,196.84120,44.071900,157.337493,7781.81840,3.231941,264.9986,358.90110,52.497440,2931.9080,-11.530910,...,584.9387,25.461175,-45.08366,-13.164050,8709.294,690.2562,783.225290,327.9715,70.15594,-27.17027,8.640390,559.740580,4628.971,224.20240,-53.05270,29.522220,-5.563776,2314.01135,503.0981,657.4658,-7.153355,-7.451403,-8.758663,18.18983,93.95804,77.390298,-4.603375,10.479710,46.410275,1507.451,668.496958,-33.981485,-1.506804,8.153008,330.3208,28.543990,5483347021_D.AVG_Signal,1,1,4
4,21.593800,-46.83287,19.975260,4.177935,-20.627690,-8.348949,-20.658600,114.62120,17.62079,273.0651,334.9753,1157.34695,375.40125,239.008843,16.301970,1167.5910,-33.888050,-2.578857,1145.476125,-29.43795,-38.146350,2017.807600,-28.550190,31.323870,24.327049,58.73265,15.156780,2446.091226,428.722387,18.052270,234.15680,39.099810,158.376267,4593.32130,-3.096215,240.0276,158.10380,-4.068097,3195.4705,-9.743137,...,390.5138,33.559935,-32.87948,-6.520766,1592.462,891.8818,2098.356200,342.5671,55.32413,-23.61512,21.325750,668.428340,5506.148,218.55710,-65.45430,65.706540,-2.885177,2050.79240,328.7022,814.5776,-6.859263,-1.464172,-22.763470,62.23862,120.07160,79.446618,27.224200,-6.651957,71.925212,1746.376,277.686380,-28.985250,-21.970280,-1.715927,961.7714,9.005090,5483347021_E.AVG_Signal,2,0,2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
532,-2.189815,-46.03613,5.749578,-12.381135,-35.507860,-19.056880,16.724690,91.31330,196.48990,103.4944,580.3944,2988.64500,408.34990,243.923583,-31.701490,612.9030,-6.780165,-23.543180,1314.362713,-40.17883,-35.576620,1989.938984,-15.461580,12.767190,-11.483525,99.92490,-15.143070,2075.880986,410.396880,14.539320,84.07471,34.214870,653.081333,3262.37795,18.328810,182.5165,60.58614,49.999610,4443.8295,33.201750,...,510.5679,39.816028,-28.99254,-27.877710,5807.106,1262.3340,1878.972000,305.1814,71.75225,-41.35888,45.078830,996.229687,10601.830,213.91500,-35.90978,33.817135,-1.273051,1463.31810,331.7239,1076.6720,43.981830,-12.465280,-8.110441,66.41614,110.42260,110.155040,12.599500,6.945223,133.361137,1148.166,447.759818,-17.679555,32.440620,-7.966771,536.7505,-18.575380,5685521033_H.AVG_Signal,2,1,5
533,25.470250,-42.13654,11.191050,-11.422420,-2.354719,-20.342854,7.849306,128.46680,40.79843,102.2480,371.9373,1300.58795,334.49135,163.513623,-36.610340,622.0648,-12.945520,-8.340496,659.970398,-36.96067,-23.042280,990.979862,5.615935,3.647199,23.322283,72.27074,-44.022660,1470.075340,218.503877,1.648164,59.06745,6.764814,286.405233,1578.32235,-4.809433,177.4125,66.92413,38.177300,4460.5135,74.455830,...,227.7927,47.411535,-17.86875,-27.597810,2183.061,850.8649,1254.331660,190.5734,33.63958,-28.16190,28.170690,663.856963,9544.950,173.36150,-28.70290,-1.535435,-24.503820,1024.14025,214.9556,989.5558,5.839520,2.862215,-13.436850,63.25101,75.10828,63.852785,1.056084,35.440900,56.424470,1123.725,260.548230,-8.033682,-17.464670,-1.763262,716.7267,9.340691,5685521033_I.AVG_Signal,0,1,3
534,5.203332,-37.44641,-7.461519,-2.290285,-32.947030,-17.947770,-3.632384,151.23410,67.00423,140.1396,454.0324,2020.61855,281.87040,202.314533,-20.146800,285.4071,-35.914700,40.765270,634.781051,-34.78794,-34.445480,977.607143,-9.750414,-2.055379,5.803905,93.59048,-0.261585,1396.325664,257.997140,26.054540,67.84500,12.149200,533.598633,2820.03845,15.263220,253.5119,110.30980,43.710630,4251.8785,85.120350,...,168.6748,21.935474,-41.69322,-29.897310,4859.993,907.1992,1787.070920,155.8786,36.87611,-41.74828,39.664970,855.127804,12616.890,175.71950,-39.27529,0.639645,4.070309,1131.92520,215.9942,1199.4110,30.978370,-29.115160,-11.689390,69.64512,60.09529,58.338865,51.640040,14.996640,76.147788,1595.139,233.929300,-34.181275,10.931870,-4.655565,580.3944,-8.010675,5685521033_J.AVG_Signal,0,1,3
535,-11.847170,-48.77342,41.844610,29.301540,-50.350070,-18.222930,6.361732,7.26864,456.90490,212.5967,240.9312,1553.50290,367.43405,329.663760,-16.153490,1198.0950,-42.081510,-29.159050,1012.875630,-18.89009,-24.052110,2656.682369,-24.556000,21.960930,16.408357,115.41050,28.417200,1574.131868,449.429540,-1.567851,105.34070,21.978490,136.092857,3152.96870,32.736020,194.3142,255.83760,83.129100,3297.5555,-13.127830,...,583.0040,58.436177,-39.68929,-7.578042,1861.644,821.5377,1981.779850,544.7303,58.56222,-23.44128,6.188788,777.306373,6261.345,143.57140,-54.88902,55.307040,30.162160,1455.17080,342.2286,773.6067,3.814982,16.976040,14.148190,58.41484,89.83482,78.295545,8.890541,38.712060,87.952589,1108.522,351.473153,-23.262255,24.555860,-2.064317,509.8944,-22.590360,5685521033_K.AVG_Signal,0,0,0


In [0]:
labels = list(set(data['label'].values))

In [0]:
datacl = data.drop(["tuberculosis", "hiv"], axis=1)
datacl.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,536,537,538,539,540,541,542,543,544,545,546,547,548,549,550,551,552,553,554,555,556,557,558,559,560,561,562,563,564,565,566,567,568,569,570,571,572,573,host_name,label
0,22.97779,-46.11452,3.379508,-4.29857,-30.65671,-21.974197,6.757988,50.23626,220.9977,203.1971,708.7916,1661.44655,470.32645,238.714727,3.026244,1920.947,-23.45619,9.644024,840.716017,-46.44247,-47.32383,2301.86,-33.64693,27.44354,18.53498,46.75706,2.794634,2873.50476,631.742823,21.02171,361.703,73.88892,129.57385,7617.1805,-6.016032,294.3749,324.1916,50.51965,2569.3675,3.682932,...,-13.816249,-1.280018,427.2864,44.471532,-33.76612,-7.858823,1691.314,837.6642,1254.680995,255.5322,49.75186,-26.93139,-11.99573,545.262778,3484.594,199.4576,-55.05777,35.053051,3.399793,2509.5388,536.0202,454.5185,-5.451852,-16.09967,-16.33179,42.69594,115.295,74.27845,-9.138244,11.75266,59.211828,1428.479,800.344655,-18.81674,-18.3052,8.300746,659.5989,4.898625,5483347021_A.AVG_Signal,1
1,24.39,-49.27365,4.602244,-17.40606,-26.89246,-16.2377,-2.187479,45.99354,168.134,205.058,405.6895,2743.8015,395.42655,196.898758,5.178334,1014.863,-19.0241,1.942063,1012.879304,-44.20679,-36.457795,1678.046375,-28.45373,33.46179,8.59546,109.8281,69.42052,3046.88572,499.203333,25.95582,184.6671,39.46527,219.375267,8029.155,26.70292,196.4899,419.515,51.56081,3544.618,23.89272,...,16.244684,-25.9005,527.6553,35.925996,-26.31075,-3.95661,3886.89,1194.945,1717.7041,205.8496,26.72584,-30.3398,6.08599,828.55611,4562.318,144.437,-48.45786,39.471365,-1.511332,1786.243,501.2137,896.3716,76.16011,-5.510161,-19.9116,80.4659,96.16717,110.483339,9.100771,-6.079873,84.30254,1199.411,357.877343,-26.722555,-9.039674,5.994009,651.8876,-13.9968,5483347021_B.AVG_Signal,1
2,25.7284,-41.38828,23.11395,0.70102,-13.57948,-23.266679,8.795504,103.6721,98.01786,209.2793,742.9699,2569.884,480.4597,210.410058,-10.7458,434.576,-32.96194,-2.481033,853.980037,-31.64151,-42.335595,1410.080053,-39.06676,7.241544,6.27494,50.48791,9.098358,2229.95236,454.817027,12.84777,176.5514,26.90661,251.8095,5854.09705,4.223186,266.955,99.26867,37.87567,4210.851,22.26756,...,-5.160167,-3.680539,558.265,58.951976,-31.40832,-20.63465,7450.395,1110.931,1372.88035,206.0475,39.9466,-31.13787,22.86758,780.782468,10513.32,198.2308,-57.15127,24.87198,-15.06822,2032.3215,405.7916,670.4788,5.320835,-5.530854,-3.949874,58.19359,90.01529,76.60233,10.85631,21.09592,91.501303,1535.089,566.085672,-24.95964,-16.22259,6.536697,474.1375,9.648779,5483347021_C.AVG_Signal,4
3,40.31339,-45.15408,-7.856186,-4.906975,-19.61078,-11.327387,1.58802,36.84473,123.8679,196.4014,517.2504,5174.6645,443.8134,170.04564,-2.900939,1044.246,-30.3814,-5.812512,1158.235857,-40.83022,-48.552395,1366.32921,-18.84179,3.31679,15.47754,81.11749,-10.55208,2368.19198,477.77,9.402896,196.8412,44.0719,157.337493,7781.8184,3.231941,264.9986,358.9011,52.49744,2931.908,-11.53091,...,-7.187602,-23.82788,584.9387,25.461175,-45.08366,-13.16405,8709.294,690.2562,783.22529,327.9715,70.15594,-27.17027,8.64039,559.74058,4628.971,224.2024,-53.0527,29.52222,-5.563776,2314.01135,503.0981,657.4658,-7.153355,-7.451403,-8.758663,18.18983,93.95804,77.390298,-4.603375,10.47971,46.410275,1507.451,668.496958,-33.981485,-1.506804,8.153008,330.3208,28.54399,5483347021_D.AVG_Signal,4
4,21.5938,-46.83287,19.97526,4.177935,-20.62769,-8.348949,-20.6586,114.6212,17.62079,273.0651,334.9753,1157.34695,375.40125,239.008843,16.30197,1167.591,-33.88805,-2.578857,1145.476125,-29.43795,-38.14635,2017.8076,-28.55019,31.32387,24.327049,58.73265,15.15678,2446.091226,428.722387,18.05227,234.1568,39.09981,158.376267,4593.3213,-3.096215,240.0276,158.1038,-4.068097,3195.4705,-9.743137,...,-7.834397,-11.38081,390.5138,33.559935,-32.87948,-6.520766,1592.462,891.8818,2098.3562,342.5671,55.32413,-23.61512,21.32575,668.42834,5506.148,218.5571,-65.4543,65.70654,-2.885177,2050.7924,328.7022,814.5776,-6.859263,-1.464172,-22.76347,62.23862,120.0716,79.446618,27.2242,-6.651957,71.925212,1746.376,277.68638,-28.98525,-21.97028,-1.715927,961.7714,9.00509,5483347021_E.AVG_Signal,2


# Multi Model Approach (Reduce After Split)

## Model 0

### label = 0

In [0]:
# data_0 = datacl[datacl['label']==0]
# host_names_0 = data_0.host_name.values
# label = data_0.label.values
# data_0 = data_0.drop(["host_name", "label"], axis=1)
# data_0.head()
# data_0.shape
# # Perform PCA
# data_0 = pca.fit_transform(data_0)
# data_0 = pd.DataFrame(data_0)
# columns = list(data_0.columns.values)
# model0 = ctgan.CTGANSynthesizer(batch_size=10, gen_dim=(256, 256), dis_dim=(256, 256))
# model0.fit(data_0, columns, epochs = 1000)
# samples1 = model0.sample(100)
# samples2 = model0.sample(50)
# samples3 = model0.sample(150)

# for sample in [samples1, samples2, samples3]:
#     size = sample.shape[0]
#     name = f"./sample_{size}_model0.csv"
#     sample = pd.DataFrame(sample)
#     sample.to_csv(name)

# data_0.to_csv("./data0.csv")
# for sample in [samples1, samples2, samples3]:
#   sample.reset_index(drop=True, inplace=True)
#   plot_tsne(sample, "./data0.csv", f"_model_0_{sample.shape[0]}")

## Model 1

### label 1

In [0]:
# data_1 = datacl[datacl['label']==1]
# host_names_1 = data_1.host_name.values
# label = data_1.label.values
# data_1 = data_1.drop(["host_name", "label"], axis=1)
# data_1.head()
# data_1.shape
# # Perform PCA
# data_1 = pca.fit_transform(data_1)
# data_1 = pd.DataFrame(data_1)
# columns = list(data_1.columns.values)
# model1 = ctgan.CTGANSynthesizer(batch_size=10, gen_dim=(256, 256), dis_dim=(256, 256))
# model1.fit(data_1, columns, epochs = 1000)
# samples1 = model1.sample(100)
# samples2 = model1.sample(50)
# samples3 = model1.sample(150)

# for sample in [samples1, samples2, samples3]:
#     size = sample.shape[0]
#     name = f"./sample_{size}_model1.csv"
#     sample = pd.DataFrame(sample)
#     sample.to_csv(name)

# data_1.to_csv("./data1.csv")
# for sample in [samples1, samples2, samples3]:
#   sample.reset_index(drop=True, inplace=True)
#   plot_tsne(sample, "./data1.csv", f"_model_1_{sample.shape[0]}")

## Model 2
### label 2

In [0]:
# data_2 = datacl[datacl['label']==2]
# host_names_2 = data_2.host_name.values
# label = data_2.label.values
# data_2 = data_2.drop(["host_name", "label"], axis=1)
# data_2.head()
# data_2.shape
# # Perform PCA
# data_2 = pca.fit_transform(data_2)
# data_2 = pd.DataFrame(data_2)
# columns = list(data_2.columns.values)
# model2 = ctgan.CTGANSynthesizer(batch_size=10, gen_dim=(256, 256), dis_dim=(256, 256))
# model2.fit(data_2, columns, epochs = 1000)
# samples1 = model2.sample(100)
# samples2 = model2.sample(50)
# samples3 = model2.sample(150)

# for sample in [samples1, samples2, samples3]:
#     size = sample.shape[0]
#     name = f"./sample_{size}_model2.csv"
#     sample = pd.DataFrame(sample)
#     sample.to_csv(name)

# data_2.to_csv("./data2.csv")
# for sample in [samples1, samples2, samples3]:
#   sample.reset_index(drop=True, inplace=True)
#   plot_tsne(sample, "./data2.csv", f"_model_2_{sample.shape[0]}")

## Model 3
# label 3

In [0]:
# data_3 = datacl[datacl['label']==3]
# host_names_3 = data_3.host_name.values
# label = data_3.label.values
# data_3 = data_3.drop(["host_name", "label"], axis=1)
# data_3.head()
# data_3.shape
# # Perform PCA
# data_3 = pca.fit_transform(data_3)
# data_3 = pd.DataFrame(data_3)
# columns = list(data_3.columns.values)
# model3 = ctgan.CTGANSynthesizer(batch_size=10, gen_dim=(256, 256), dis_dim=(256, 256))
# model3.fit(data_3, columns, epochs = 1000)
# samples1 = model3.sample(100)
# samples2 = model3.sample(50)
# samples3 = model3.sample(150)

# for sample in [samples1, samples2, samples3]:
#     size = sample.shape[0]
#     name = f"./sample_{size}_model3.csv"
#     sample = pd.DataFrame(sample)
#     sample.to_csv(name)

# data_3.to_csv("./data3.csv")
# for sample in [samples1, samples2, samples3]:
#   sample.reset_index(drop=True, inplace=True)
#   plot_tsne(sample, "./data3.csv", f"_model_3_{sample.shape[0]}")

## Model 4

### label 4

In [0]:
# data_4 = datacl[datacl['label']==4]
# host_names_4 = data_4.host_name.values
# label = data_4.label.values
# data_4 = data_4.drop(["host_name", "label"], axis=1)
# data_4.head()
# data_4.shape
# # Perform PCA
# data_4 = pca.fit_transform(data_4)
# data_4 = pd.DataFrame(data_4)
# columns = list(data_4.columns.values)
# model4 = ctgan.CTGANSynthesizer(batch_size=10, gen_dim=(256, 256), dis_dim=(256, 256))
# model4.fit(data_4, columns, epochs = 1000)
# samples1 = model4.sample(100)
# samples2 = model4.sample(50)
# samples3 = model4.sample(150)

# for sample in [samples1, samples2, samples3]:
#     size = sample.shape[0]
#     name = f"./sample_{size}_model4.csv"
#     sample = pd.DataFrame(sample)
#     sample.to_csv(name)

# data_4.to_csv("./data4.csv")
# for sample in [samples1, samples2, samples3]:
#   sample.reset_index(drop=True, inplace=True)
#   plot_tsne(sample, "./data4.csv", f"_model_4_{sample.shape[0]}")

## Model 5
### label 5

In [0]:
# data_5 = datacl[datacl['label']==5]
# host_names_5 = data_5.host_name.values
# label = data_5.label.values
# data_5 = data_5.drop(["host_name", "label"], axis=1)
# data_5.head()
# data_5.shape
# # Perform PCA
# data_5 = pca.fit_transform(data_5)
# data_5 = pd.DataFrame(data_5)
# columns = list(data_5.columns.values)
# model5 = ctgan.CTGANSynthesizer(batch_size=10, gen_dim=(256, 256), dis_dim=(256, 256))
# model5.fit(data_5, columns, epochs = 1000)
# samples1 = model5.sample(100)
# samples2 = model5.sample(50)
# samples3 = model5.sample(150)

# for sample in [samples1, samples2, samples3]:
#     size = sample.shape[0]
#     name = f"./sample_{size}_model5.csv"
#     sample = pd.DataFrame(sample)
#     sample.to_csv(name)

# data_5.to_csv("./data5.csv")
# for sample in [samples1, samples2, samples3]:
#   sample.reset_index(drop=True, inplace=True)
#   plot_tsne(sample, "./data5.csv", f"_model_5_{sample.shape[0]}")

# Multi Model Approach (Reduce Before Split)

In [0]:
# host_names = datacl.host_name
# labels = datacl.label
# datacl = datacl.drop(["host_name", "label"], axis=1)
# data_reduced = pca.fit_transform(datacl)
# data_reduced = pd.DataFrame(data_reduced)
# data_reduced = pd.concat([data_reduced, host_names, labels], axis=1)
# data_reduced.head()

## Model 0 
### label 0

In [0]:
# data_0 = data_reduced[data_reduced['label']==0]
# host_names_0 = data_0.host_name.values
# label = data_0.label.values
# data_0 = data_0.drop(["host_name", "label"], axis=1)
# data_0.head()
# data_0.shape
# columns = list(data_0.columns.values)
# model0 = ctgan.CTGANSynthesizer(batch_size=10, gen_dim=(256, 256), dis_dim=(256, 256))
# model0.fit(data_0, columns, epochs = 1000)
# samples1 = model0.sample(100)
# samples2 = model0.sample(50)
# samples3 = model0.sample(150)

# for sample in [samples1, samples2, samples3]:
#     size = sample.shape[0]
#     name = f"./sample_{size}_model0.csv"
#     sample = pd.DataFrame(sample)
#     sample.to_csv(name)

# data_0.to_csv("./data0.csv")
# for sample in [samples1, samples2, samples3]:
#   sample.reset_index(drop=True, inplace=True)
#   plot_tsne(sample, "./data0.csv", f"_model_0_{sample.shape[0]}")

## Model 1
### label 1

In [0]:
# data_1 = data_reduced[data_reduced['label']==1]
# host_names_1 = data_1.host_name.values
# label = data_1.label.values
# data_1 = data_1.drop(["host_name", "label"], axis=1)
# data_1.head()
# data_1.shape
# columns = list(data_1.columns.values)
# model1 = ctgan.CTGANSynthesizer(batch_size=10, gen_dim=(256, 256), dis_dim=(256, 256))
# model1.fit(data_1, columns, epochs = 1000)
# samples1 = model1.sample(100)
# samples2 = model1.sample(50)
# samples3 = model1.sample(150)

# for sample in [samples1, samples2, samples3]:
#     size = sample.shape[0]
#     name = f"./sample_{size}_model1.csv"
#     sample = pd.DataFrame(sample)
#     sample.to_csv(name)

# data_1.to_csv("./data1.csv")
# for sample in [samples1, samples2, samples3]:
#   sample.reset_index(drop=True, inplace=True)
#   plot_tsne(sample, "./data1.csv", f"_model_1_{sample.shape[0]}")

## Model 2
### label 2

In [0]:
# data_2 = data_reduced[data_reduced['label']==2]
# host_names_2 = data_2.host_name.values
# label = data_2.label.values
# data_2 = data_2.drop(["host_name", "label"], axis=1)
# data_2.head()
# data_2.shape
# columns = list(data_2.columns.values)
# model2 = ctgan.CTGANSynthesizer(batch_size=10, gen_dim=(256, 256), dis_dim=(256, 256))
# model2.fit(data_2, columns, epochs = 1000)
# samples1 = model2.sample(100)
# samples2 = model2.sample(50)
# samples3 = model2.sample(150)

# for sample in [samples1, samples2, samples3]:
#     size = sample.shape[0]
#     name = f"./sample_{size}_model2.csv"
#     sample = pd.DataFrame(sample)
#     sample.to_csv(name)

# data_2.to_csv("./data2.csv")
# for sample in [samples1, samples2, samples3]:
#   sample.reset_index(drop=True, inplace=True)
#   plot_tsne(sample, "./data2.csv", f"_model_2_{sample.shape[0]}")

## Model 3
### label 3

In [0]:
# data_3 = data_reduced[data_reduced['label']==3]
# host_names_3 = data_3.host_name.values
# label = data_3.label.values
# data_3 = data_3.drop(["host_name", "label"], axis=1)
# data_3.head()
# data_3.shape
# columns = list(data_3.columns.values)
# model3 = ctgan.CTGANSynthesizer(batch_size=10, gen_dim=(256, 256), dis_dim=(256, 256))
# model3.fit(data_3, columns, epochs = 1000)
# samples1 = model3.sample(100)
# samples2 = model3.sample(50)
# samples3 = model3.sample(150)

# for sample in [samples1, samples2, samples3]:
#     size = sample.shape[0]
#     name = f"./sample_{size}_model3.csv"
#     sample = pd.DataFrame(sample)
#     sample.to_csv(name)

# data_3.to_csv("./data3.csv")
# for sample in [samples1, samples2, samples3]:
#   sample.reset_index(drop=True, inplace=True)
#   plot_tsne(sample, "./data3.csv", f"_model_3_{sample.shape[0]}")

## Model 4
### label 4

In [0]:
# data_4 = data_reduced[data_reduced['label']==4]
# host_names_4 = data_4.host_name.values
# label = data_4.label.values
# data_4 = data_4.drop(["host_name", "label"], axis=1)
# data_4.head()
# data_4.shape
# columns = list(data_4.columns.values)
# model4 = ctgan.CTGANSynthesizer(batch_size=10, gen_dim=(256, 256), dis_dim=(256, 256))
# model4.fit(data_4, columns, epochs = 1000)
# samples1 = model4.sample(100)
# samples2 = model4.sample(50)
# samples3 = model4.sample(150)

# for sample in [samples1, samples2, samples3]:
#     size = sample.shape[0]
#     name = f"./sample_{size}_model4.csv"
#     sample = pd.DataFrame(sample)
#     sample.to_csv(name)

# data_4.to_csv("./data4.csv")
# for sample in [samples1, samples2, samples3]:
#   sample.reset_index(drop=True, inplace=True)
#   plot_tsne(sample, "./data4.csv", f"_model_4_{sample.shape[0]}")

## Model 5
### label 5

In [0]:
# data_5 = data_reduced[data_reduced['label']==5]
# host_names_5 = data_5.host_name.values
# label = data_5.label.values
# data_5 = data_5.drop(["host_name", "label"], axis=1)
# data_5.head()
# data_5.shape
# columns = list(data_5.columns.values)
# model5 = ctgan.CTGANSynthesizer(batch_size=10, gen_dim=(256, 256), dis_dim=(256, 256))
# model5.fit(data_5, columns, epochs = 1000)
# samples1 = model5.sample(100)
# samples2 = model5.sample(50)
# samples3 = model5.sample(150)

# for sample in [samples1, samples2, samples3]:
#     size = sample.shape[0]
#     name = f"./sample_{size}_model5.csv"
#     sample = pd.DataFrame(sample)
#     sample.to_csv(name)

# data_5.to_csv("./data5.csv")
# for sample in [samples1, samples2, samples3]:
#   sample.reset_index(drop=True, inplace=True)
#   plot_tsne(sample, "./data5.csv", f"_model_5_{sample.shape[0]}")

# Single Model Approach


In [0]:
host_names = datacl.host_name
labels = datacl.label
datacl = datacl.drop(["host_name", "label"], axis=1)
data_reduced = pca.fit_transform(datacl)
data_reduced = pd.DataFrame(data_reduced)
data_reduced = pd.concat([data_reduced, labels], axis=1)

In [0]:
columns = list(data_reduced.columns.values)
single_model = ctgan.CTGANSynthesizer(batch_size=10, gen_dim=(256, 256), dis_dim=(256, 256))
single_model.fit(data_reduced, columns, epochs = 1000)
samples1 = single_model.sample(100)
samples2 = single_model.sample(50)
samples3 = single_model.sample(150)

for sample in [samples1, samples2, samples3]:
    size = sample.shape[0]
    name = f"./sample_{size}_model0.csv"
    sample = pd.DataFrame(sample)
    sample.to_csv(name)

data_reduced.to_csv("./data_reduced.csv")

for sample in [samples1, samples2, samples3]:
    sample.reset_index(drop=True, inplace=True)
    plot_tsne(sample, "./data_reduced.csv", f"_single_model_{sample.shape[0]}")

Epoch 1, Loss G: 5.3704, Loss D: -0.0129
Epoch 2, Loss G: 6.0373, Loss D: 0.0190
Epoch 3, Loss G: 6.2596, Loss D: 0.0441
Epoch 4, Loss G: 5.2957, Loss D: -0.0485
Epoch 5, Loss G: 6.6037, Loss D: 0.1456
Epoch 6, Loss G: 6.0941, Loss D: -0.1163
Epoch 7, Loss G: 6.6973, Loss D: 0.0143
Epoch 8, Loss G: 5.6214, Loss D: -0.0304
Epoch 9, Loss G: 5.9446, Loss D: -0.1219
Epoch 10, Loss G: 5.7318, Loss D: -0.0541
Epoch 11, Loss G: 6.7497, Loss D: -0.0541
Epoch 12, Loss G: 5.6045, Loss D: -0.2013
Epoch 13, Loss G: 6.3807, Loss D: -0.0960
Epoch 14, Loss G: 7.2708, Loss D: -0.1756
Epoch 15, Loss G: 7.1211, Loss D: -0.0708
Epoch 16, Loss G: 7.0056, Loss D: -0.1111
Epoch 17, Loss G: 5.4749, Loss D: -0.2327
Epoch 18, Loss G: 6.9754, Loss D: -0.2685
Epoch 19, Loss G: 5.3179, Loss D: -0.1376
Epoch 20, Loss G: 7.3985, Loss D: -0.1862
Epoch 21, Loss G: 5.4182, Loss D: 0.0452
Epoch 22, Loss G: 7.0432, Loss D: -0.0767
Epoch 23, Loss G: 6.1908, Loss D: -0.0786
Epoch 24, Loss G: 5.7463, Loss D: -0.3301
Epoch 

AttributeError: ignored