In [1]:
import pandas as pd
from datetime import timedelta
pd.set_option("display.max_columns", None)
pd.set_option("display.max_rows", None)
import glob
from pathlib import Path

from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report

import requests
import yfinance as yf
import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

In [2]:
# Import CSV Containing IPO Information - Ticker, Industry, # of Shares Offered, Offer Price, 1st Day % Return (IPO Scoop)
# Later cells read these columns: Symbol, Industry, Offer Date, Shares (millions),
# Offer Price, 1st Day Close and Return (a fraction, e.g. 0.26 for +26%).
ipo_csv = pd.read_csv('IPO Project - Sheet2.csv')

In [3]:
# Drop unnecessary columns
ipo_df = ipo_csv.drop(columns=['1st Day Close'])

# Assign stock ticker as index
ipo_df.set_index('Symbol',inplace=True)

# Build the date bounds used for the market-return lookup further down.
# NOTE: despite its name, `three_months` is 100 days, so 'T-90D' actually holds
# (offer date - 100 days). The span from 'T-90D' to 'T-10D' is therefore the
# 90 calendar days that end 10 days before the offer date.
ipo_df['Offer Date'] = pd.to_datetime(ipo_df['Offer Date'])
ten_days = pd.to_timedelta(10,'d')
three_months = pd.to_timedelta(100,'d')

ipo_df['T-10D'] = ipo_df['Offer Date']-ten_days
ipo_df['T-90D'] = ipo_df['Offer Date']-three_months
ipo_df['T+100D'] = ipo_df['Offer Date']+three_months

In [4]:
# Spot-check the derived date-window columns on the first three rows
ipo_df.head(3)

Unnamed: 0_level_0,Industry,Offer Date,Shares (millions),Offer Price,Return,T-10D,T-90D,T+100D
Symbol,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
ABOS,Health Care,2021-07-01,10.0,16.0,0.26,2021-06-21,2021-03-23,2021-10-09
ABSI,Health Care,2021-07-22,12.5,16.0,0.35,2021-07-12,2021-04-13,2021-10-30
ACT,Financials,2021-09-16,13.3,19.0,0.08,2021-09-06,2021-06-08,2021-12-25


In [5]:
# The window bounds are handed to yf.download further down, so render them as text.
for window_col in ('T-10D', 'T-90D', 'T+100D'):
    ipo_df[window_col] = ipo_df[window_col].astype(str)

In [5]:
# Measure the SPY return over each IPO's pre-offer window: from 'T-90D'
# (offer date - 100 days) to 'T-10D' (offer date - 10 days).
for ticker in ipo_df.index:
    window_start = ipo_df.loc[ticker, 'T-90D']
    window_end = ipo_df.loc[ticker, 'T-10D']
    data = yf.download('SPY', window_start, window_end, progress=False)

    # A failed request comes back as an empty frame; stop with a clear message
    # instead of an opaque IndexError from the positional lookups below.
    if data.empty:
        raise ValueError(
            f'No SPY prices returned for {window_start} to {window_end} (ticker {ticker})'
        )

    # Second-to-last column of the download, selected by position as before
    # (assumed to be 'Adj Close' - TODO confirm for the installed yfinance version).
    prices = data.iloc[:, -2]
    ipo_df.loc[ticker, 'SPY 90D Return'] = (prices.iloc[-1] - prices.iloc[0]) / prices.iloc[0]

# NOTE: a per-ticker return from 'Offer Date' to 'T+100D' was sketched here
# previously but is disabled; 'T+100D' is currently unused.

# Encode Offer Date as its ISO week of the year. This is done once for the whole
# column (not cell-by-cell inside the loop) so the datetime column never holds a
# mix of timestamps and ints, and it is skipped if the column was already encoded.
if pd.api.types.is_datetime64_any_dtype(ipo_df['Offer Date']):
    ipo_df['Offer Date'] = ipo_df['Offer Date'].map(
        lambda offer_date: offer_date.isocalendar()[1]
    )

In [6]:
# Binary target: 1 when the first-day return reached at least 10%, otherwise 0
ipo_df['10% Returns?'] = (ipo_df['Return'] >= .1).astype('int64')

# Replace each "Industry" label with its integer code
le = LabelEncoder()
ipo_df["Industry"] = le.fit_transform(ipo_df["Industry"])

# The helper date-window columns are no longer needed
ipo_df.drop(columns=['T-10D','T-90D','T+100D'], inplace=True)

ipo_df.head(3)

Unnamed: 0_level_0,Industry,Offer Date,Shares (millions),Offer Price,Return,SPY 90D Return,10% Returns?
Symbol,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
ABOS,4,26,10.0,16.0,0.26,0.068748,1
ABSI,4,29,12.5,16.0,0.35,0.058337,1
ACT,3,37,13.3,19.0,0.08,0.076448,0


In [7]:
# Use glob to create categorized lists of Yahoo Finance CSV's.
# glob returns matches in arbitrary (filesystem-dependent) order, so each list is
# sorted: step_3 combines the three categories row-by-row and takes the tickers
# from the balance-sheet list, which only works if every list has the same,
# reproducible ordering.
bs_list = sorted(glob.glob("*_annual_balance-sheet.csv"))
cf_list = sorted(glob.glob("*_annual_cash-flow.csv"))
af_list = sorted(glob.glob("*_annual_financials.csv"))
list_of_lists = [bs_list,cf_list,af_list]

In [8]:
# Clean and organize CSV
def step_1(frame):
    """Collapse one raw statement CSV into a single-row frame of line items.

    Args:
        frame: DataFrame with a 'name' column (the line-item labels) followed
            by one column per reporting period, optionally including 'ttm'.

    Returns:
        A one-row DataFrame: an 'index' column holding the label of the period
        that was kept, then one column per line item (names whitespace-trimmed).
        The input frame is not modified.
    """
    by_item = frame.set_index('name')

    # TTM figures may be more recent than the IPO date, so they are discarded.
    if 'ttm' in by_item.columns:
        by_item = by_item.drop(columns='ttm')

    # Keep only the first remaining period (assumed to be the latest annual
    # report before the IPO - TODO confirm column order) and zero-fill the gaps.
    first_period = by_item.iloc[:, 0].fillna(0)

    # One row, one column per line item; the period label moves into 'index'.
    single_row = first_period.to_frame().T.reset_index()

    # Trim stray whitespace (e.g. leading tabs) from the line-item names.
    single_row.columns = single_row.columns.str.strip()
    return single_row

In [9]:
#Create dataframe for CSV per respective category
def step_2(list, zero_threshold=0.50):
    """Build one table (one row per file) from same-category Yahoo Finance CSVs.

    Each file is read and reduced to a single row by ``step_1``; the rows are
    stacked, gaps are filled with 0 and mostly-zero columns are removed.

    Args:
        list: Iterable of CSV paths for one statement category. (The name
            shadows the builtin; it is kept so existing callers are unaffected.)
        zero_threshold: A column is dropped when at least this fraction of its
            rows equals 0. Defaults to 0.50, the value previously hard-coded.

    Returns:
        DataFrame with a 0..n-1 index whose rows follow the order of ``list``.
        An empty DataFrame when ``list`` is empty.
    """
    # Read and clean every file first, then concatenate once rather than
    # growing the frame inside the loop.
    rows = [step_1(pd.read_csv(Path(file), engine='python')) for file in list]
    if not rows:
        return pd.DataFrame()

    # Line items missing from a given file become NaN when stacked; treat as 0.
    df = pd.concat(rows, axis=0).fillna(0)

    # Remove "\t" and other surrounding whitespace from column names
    df.columns = df.columns.str.strip()

    # 'index' only carries the reporting-period label added by step_1.
    df = df.reset_index(drop=True).drop(columns='index')

    # Drop columns that are 0 in at least `zero_threshold` of the rows.
    # NOTE(review): the first column is never tested (columns[1:]). Kept as-is
    # so the resulting feature set does not change - confirm this is intended.
    cutoff = zero_threshold * len(df.index)
    sparse_columns = [
        column for column in df.columns[1:]
        if (df[column] == 0).sum() >= cutoff
    ]
    return df.drop(columns=sparse_columns)

In [10]:
# Create dataframe concatenated from respective categorical dataframes
def step_3(list):
    """Combine the per-category statement tables side by side, indexed by ticker.

    Args:
        list: Sequence of file lists, one per statement category. Files must be
            named ``<TICKER>_...``, and every category must list the same
            tickers in the same order because rows are matched by position.
            (The name shadows the builtin; kept so callers are unaffected.)

    Returns:
        DataFrame with one row per ticker (the index) and the columns of every
        category table side by side. Empty when ``list`` is empty.

    Raises:
        ValueError: if the categories do not list the same tickers in the same
            order, which would otherwise silently mix rows of different companies.
    """
    # The ticker is everything before the first underscore of the file name,
    # so tickers of any length are handled.
    tickers_by_category = [
        [Path(name).name.split('_', 1)[0] for name in files]
        for files in list
    ]
    tickers = tickers_by_category[0] if tickers_by_category else []

    # Check alignment before any file is read.
    for position, other in enumerate(tickers_by_category[1:], start=1):
        if other != tickers:
            raise ValueError(
                f'File list {position} does not contain the same tickers, '
                'in the same order, as the first file list'
            )

    # Use the argument itself (not a notebook-level global) for the category tables.
    frames = [step_2(files) for files in list]
    df = pd.concat(frames, axis=1) if frames else pd.DataFrame()

    # Drop any row left incomplete after the side-by-side concat.
    df = df.dropna()

    df.index = tickers
    return df

In [11]:
#Running Functions and Compiling DFs
# bz_df is indexed by ticker (assigned in step_3) and ipo_df by 'Symbol', so the
# column-wise concat lines the two tables up on the ticker label.
bz_df = step_3(list_of_lists)
bz_ipo_df = pd.concat([bz_df, ipo_df],axis=1)

In [12]:
# Preview the merged table: financial-statement columns followed by the IPO columns
bz_ipo_df.head(10)

Unnamed: 0,TotalAssets,CurrentAssets,CashCashEquivalentsAndShortTermInvestments,CashAndCashEquivalents,Receivables,PrepaidAssets,TotalNonCurrentAssets,OtherNonCurrentAssets,TotalLiabilitiesNetMinorityInterest,CurrentLiabilities,PayablesAndAccruedExpenses,Payables,AccountsPayable,CurrentAccruedExpenses,OtherCurrentLiabilities,TotalNonCurrentLiabilitiesNetMinorityInterest,TotalEquityGrossMinorityInterest,StockholdersEquity,CapitalStock,CommonStock,AdditionalPaidInCapital,RetainedEarnings,TotalCapitalization,CommonStockEquity,NetTangibleAssets,WorkingCapital,InvestedCapital,TangibleBookValue,ShareIssued,OrdinarySharesNumber,NetPPE,GrossPPE,Properties,MachineryFurnitureEquipment,OtherProperties,Leases,AccumulatedDepreciation,GoodwillAndOtherIntangibleAssets,OtherIntangibleAssets,CurrentDebtAndCapitalLeaseObligation,CurrentDebt,CurrentDeferredLiabilities,CurrentDeferredRevenue,LongTermDebtAndCapitalLeaseObligation,LongTermDebt,LongTermCapitalLeaseObligation,OtherNonCurrentLiabilities,CapitalLeaseObligations,TotalDebt,AccountsReceivable,NonCurrentDeferredLiabilities,GainsLossesNotAffectingRetainedEarnings,OperatingCashFlow,CashFlowFromContinuingOperatingActivities,NetIncomeFromContinuingOperations,OperatingGainsLosses,StockBasedCompensation,ChangeInWorkingCapital,ChangeInReceivables,ChangeInPrepaidAssets,ChangeInPayablesAndAccruedExpense,ChangeInPayable,ChangeInAccountPayable,ChangeInAccruedExpense,FinancingCashFlow,CashFlowFromContinuingFinancingActivities,EndCashPosition,ChangesInCash,BeginningCashPosition,InterestPaidSupplementalData,IssuanceOfCapitalStock,FreeCashFlow,DepreciationAmortizationDepletion,DepreciationAndAmortization,ChangeInOtherWorkingCapital,InvestingCashFlow,CashFlowFromContinuingInvestingActivities,NetPPEPurchaseAndSale,PurchaseOfPPE,NetIssuancePaymentsOfDebt,NetLongTermDebtIssuance,LongTermDebtIssuance,LongTermDebtPayments,CapitalExpenditure,IssuanceOfDebt,RepaymentOfDebt,OtherNonCashItems,ChangesInAccountReceivables,NetOtherFinancingCharges,TotalRev
enue,OperatingRevenue,OperatingExpense,SellingGeneralAndAdministration,GeneralAndAdministrativeExpense,OtherGandA,ResearchAndDevelopment,OperatingIncome,NetNonOperatingInterestIncomeExpense,OtherIncomeExpense,PretaxIncome,NetIncomeCommonStockholders,NetIncome,NetIncomeIncludingNoncontrollingInterests,NetIncomeContinuousOperations,DilutedNIAvailtoComStockholders,BasicEPS,DilutedEPS,BasicAverageShares,DilutedAverageShares,TotalOperatingIncomeAsReported,TotalExpenses,NetIncomeFromContinuingAndDiscontinuedOperation,NormalizedIncome,NetInterestIncome,EBIT,NetIncomeFromContinuingOperationNetMinorityInterest,TotalUnusualItemsExcludingGoodwill,TotalUnusualItems,NormalizedEBITDA,TaxRateForCalcs,TaxEffectOfUnusualItems,InterestExpenseNonOperating,OtherNonOperatingIncomeExpenses,InterestExpense,ReconciledDepreciation,TaxProvision,CostOfRevenue,GrossProfit,ReconciledCostOfRevenue,Industry,Offer Date,Shares (millions),Offer Price,Return,SPY 90D Return,10% Returns?
ABOS,44429000,44429000,43777000,43777000,109000,91000,0,0,63020000,6367000,864000,531000,531000,333000,5503000,56653000,-18591000,-18591000,0,0,8374000,-26965000,-18591000,-18591000,-18591000,38062000,-18591000,-18591000,38651795,38651795,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,-7450000,-7450000,-7325000,-586000,154000,307000,-79000,53000,189000,308000,308000,-119000,44675000,44675000,43777000,37225000,6552000,0,44675000,-7450000,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1436000,1436000,9348000,1351000,1351000,1351000,7997000,-7912000,1000,586000,-7325000,-7911000,-7325000,-7325000,-7325000,-7911000,-0.28,-0.28,28651796,28651796,-7912000,9348000,-7325000,-7911000,1000,-7912000,-7325000,586000,586000,-8498000,0.0,0,0,0,0,0,0,0,0,0,4,26,10.0,16.0,0.26,0.068748,1
ABSI,88569000,73234000,69867000,69867000,1594000,1773000,15335000,1950000,177997000,10095000,3685000,2116000,2116000,1569000,0,167902000,-89428000,-89428000,2000,2000,635000,-90065000,-85287000,-89428000,-89428000,63139000,-83752000,-89428000,90375022,90375022,13385000,15730000,0,660000,13054000,2016000,-2345000,188000,188000,3780000,1535000,2630000,2630000,10720000,4141000,6579000,749000,8824000,14500000,0,0,0,-10970000,-10970000,-14353000,824000,420000,1008000,-1372000,-1434000,1956000,903000,903000,1053000,70973000,70973000,71708000,57832000,13876000,508000,69334000,-13151000,1131000,1131000,1943000,-2171000,-2171000,-2171000,-2181000,1639000,1639000,3230000,-1591000,-2181000,3230000,-1591000,0,0,0,4780000,4780000,18081000,5502000,0,0,11448000,-13301000,-634000,-418000,-14353000,-49469000,-14353000,-14353000,-14353000,-49469000,-0.547,-0.547,90375022,90375022,-13301000,18081000,-14353000,-14353000,-634000,-13719000,-14353000,0,0,-12588000,0.0,0,634000,-418000,634000,1131000,0,0,0,0,4,29,12.5,16.0,0.35,0.058337,1
ACT,5652710000,0,0,452794000,46464000,0,0,0,1770899000,0,0,0,0,0,0,0,3881811000,3881811000,1628000,1628000,2368699000,1303106000,4619973000,3881811000,3881811000,0,4619973000,3881811000,162840000,162840000,0,0,0,0,0,0,0,0,0,0,0,0,0,738162000,738162000,0,0,0,738162000,46464000,36811000,208378000,704350000,704350000,370421000,3324000,0,309889000,-5303000,0,0,0,0,0,300298000,300298000,452794000,-132264000,585058000,0,0,704350000,0,0,0,-1136912000,-1136912000,0,0,737651000,737651000,737651000,0,0,737651000,0,6720000,-5303000,0,1029947000,0,0,0,0,0,0,0,-18244000,5575000,472418000,370421000,370421000,370421000,370421000,370421000,2.275,2.275,162840000,162840000,0,634041000,370421000,370421000,0,490662000,370421000,0,0,0,0.216,0,18244000,0,18244000,0,101997000,0,0,0,3,37,13.3,19.0,0.08,0.076448,0
ACXP,3224020,3224020,3175411,3175411,0,48609,0,0,522434,472556,138863,0,0,138863,0,49878,2701586,2701586,16502198,16502198,0,-13800612,2751464,2701586,2701586,2751464,2768089,2701586,9541159,9541159,0,0,0,0,0,0,0,0,0,16625,16625,0,0,49878,49878,0,0,0,66503,0,0,0,-3351558,-3351558,-4600038,781700,1267946,-801166,0,-506,-800660,0,0,0,4043647,4043647,3175411,692089,2483322,0,3977144,-3351558,0,0,0,0,0,0,0,66503,66503,66503,0,0,66503,0,1353813,0,0,0,0,4600038,2397059,2397059,2397059,2202979,-4600038,0,0,-4600038,-4600038,-4600038,-4600038,-4600038,-4600038,-0.74,-0.74,8891338,8891338,0,4600038,-4600038,-4600038,0,-4600038,-4600038,0,0,-4600038,0.0,0,0,0,0,0,0,0,0,0,4,25,2.5,6.0,0.32,0.073989,1
AIP,42736000,28952000,11744000,11744000,15424000,1237000,13784000,1772000,54744000,28226000,7291000,1116000,1116000,6175000,574000,26518000,-12008000,-12008000,18000,18000,3612000,-15607000,-11281000,-12008000,-18094000,726000,-10081000,-18094000,29996570,29996570,5118000,6347000,0,3487000,2760000,100000,-1229000,6086000,3409000,1967000,1200000,18394000,17894000,2806000,727000,2079000,2268000,2846000,4773000,14350000,15014000,-31000,2163000,2163000,-3260000,-1567000,458000,5057000,-6324000,-2608000,3430000,414000,414000,3016000,790000,790000,11744000,-2194000,13938000,65000,0,1509000,935000,935000,11086000,-5147000,-5147000,-647000,-654000,554000,554000,1603000,-1049000,-654000,1603000,-1049000,540000,-6324000,0,31812000,31812000,34098000,17078000,7329000,7329000,17020000,-3777000,-50000,1593000,-2234000,-3260000,-3260000,-3260000,-3260000,-3260000,-0.15,-0.15,22049162,22049162,-3777000,35589000,-3260000,-4422890,-50000,-2184000,-3260000,1593000,1593000,-2842000,0.27,430110,50000,0,50000,935000,1026000,1491000,30321000,1491000,8,43,5.0,14.0,0.3,0.052541,1
AIRS,179610000,11563000,10379000,10379000,0,1184000,168047000,1544000,55934000,9457000,2934000,1095000,1095000,1839000,0,46477000,123676000,123676000,0,0,0,0,155795000,123676000,-18666000,2106000,32519000,-18666000,55359177,55359177,24161000,25260000,0,878000,19008000,5374000,-1099000,142342000,60608000,3290000,400000,3233000,3233000,46477000,32119000,14358000,0,17248000,49767000,0,0,0,13957000,13957000,7577000,0,325000,203000,0,275000,87000,-1019000,-1019000,1106000,-5017000,-5017000,10379000,5251000,5128000,2293000,0,10268000,5641000,5641000,45000,-3689000,-3689000,-3689000,-3689000,-400000,-400000,2500000,-2900000,-3689000,2500000,-2900000,211000,0,-4617000,62766000,62766000,29262000,23621000,0,0,0,10033000,-2456000,0,7577000,5750000,5750000,5750000,5750000,5750000,0.103,0.103,55640154,55640154,10033000,52733000,5750000,5750000,-2456000,10033000,5750000,0,0,15674000,0.241,0,2456000,0,2456000,5641000,1827000,23471000,39295000,23471000,4,43,7.0,11.0,0.44,0.032384,1
AKA,189439000,65486000,26259000,26259000,1183000,4056000,123953000,0,50555000,41245000,25976000,13525000,4689000,12451000,0,9310000,138884000,0,0,0,0,14138000,128901000,128901000,-117355000,24241000,135254000,11546000,126590141,126590141,6598000,7846000,0,1349000,4477000,2020000,-1248000,117355000,29102000,7587000,6353000,4165000,4165000,3262000,0,3262000,144000,4496000,10849000,1183000,5904000,5839000,21712000,21712000,14805000,0,1380000,1673000,-833000,20000,9329000,912000,-2776000,8417000,1240000,1240000,27099000,20573000,5791000,278000,450000,19933000,6762000,6762000,2587000,-2379000,-2379000,-1328000,-1328000,790000,0,0,0,-1779000,0,0,0,-833000,0,215916000,215916000,104261000,104261000,28077000,28077000,0,22140000,0,-485000,21655000,14334000,14334000,14805000,14805000,14334000,0.113,0.113,126590141,126590141,22140000,193776000,14334000,14334000,0,22140000,14334000,0,0,28902000,0.316,0,0,-485000,0,6762000,6850000,89515000,126401000,89515000,1,38,10.0,11.0,-0.09,0.05088,0
ALZN,2912590,2912590,1929270,1929270,16210,614608,0,0,899643,899643,564340,564340,503591,0,0,0,2012947,2012947,6818,6743,33721860,-16832436,2012947,2012872,2012947,2012947,2348175,2012872,84929525,84929525,0,0,0,0,0,0,0,0,0,335303,335303,0,0,0,0,0,0,0,335303,0,0,0,-2712027,-2712027,-5046567,-62418,2032359,-152451,0,260791,-413242,0,0,0,4450097,4450097,1929270,1838985,90285,0,2100000,-2712027,0,0,0,100915,100915,0,0,400192,400192,402110,-1918,0,402110,-1918,517050,0,1949905,0,0,4951888,3641172,3641172,865441,1310716,-4951888,-157097,62418,-5046567,-5046567,-5046567,-5046567,-5046567,-5046567,-0.059,-0.059,84929525,84929525,-4951888,4951888,-5046567,-5108985,-157097,-4887764,-5046567,62418,62418,-4950182,0.0,0,158803,0,158803,0,0,0,0,0,4,24,2.5,5.0,1.7,0.11072,1
AMAM,134021000,93077000,90462000,90462000,982000,626000,40944000,624000,273455000,13260000,4405000,2820000,2820000,1585000,0,260195000,-139434000,-139434000,0,0,6805000,-145553000,-139434000,-139434000,-176263000,79817000,-139434000,-176263000,263686190,263686190,3491000,7807000,0,526000,6986000,295000,-4316000,36829000,36829000,1595000,0,6470000,6470000,1598000,0,1598000,138000,3193000,3193000,428000,4141000,-686000,-20576000,-20576000,-17839000,0,1217000,-7375000,415000,-656000,-3857000,-3247000,-3247000,-610000,95678000,95678000,91278000,74850000,16263000,7000,95678000,-20828000,1943000,1943000,-1280000,-252000,-252000,-252000,-252000,0,0,1339000,-1339000,-252000,1339000,-1339000,1478000,415000,0,13671000,13671000,26786000,6353000,6353000,6353000,20433000,-13115000,27000,-4750000,-17838000,-16543000,-16543000,-17839000,-17839000,-16543000,-0.441,-0.441,37669456,37669456,-13115000,26786000,-16543000,-16543000,27000,-13115000,-16543000,0,0,-11172000,0.0,0,0,-4750000,0,1943000,1000,0,0,0,4,24,7.0,18.0,-0.05,0.087255,0
AMPL,175082000,148679000,117783000,117783000,17396000,6857000,26403000,6898000,242202000,53324000,11633000,4673000,4417000,6960000,894000,188878000,-67120000,-67120000,0,0,37704000,-104824000,-67120000,-67120000,-70075000,95355000,-67120000,-70075000,102444655,102444655,2673000,4692000,0,3878000,0,814000,-2019000,2955000,1955000,0,0,40797000,40797000,0,0,0,1067000,0,0,17396000,0,0,-10392000,-10392000,-24567000,0,16553000,-4319000,-5561000,-1469000,4140000,2414000,2414000,1726000,54245000,54245000,118863000,37945000,80918000,0,49820000,-12600000,1690000,1690000,3867000,-5908000,-5908000,-984000,-984000,0,0,0,0,-2208000,0,0,0,-5561000,-2000,102464000,102464000,95984000,69886000,18067000,18067000,26098000,-24003000,0,269000,-23734000,-28246000,-24567000,-24567000,-24567000,-28246000,-0.3,-0.3,95594997,95594997,0,126467000,-24567000,-24567000,0,-24003000,-24567000,0,0,-22313000,0.27,0,0,269000,0,1690000,833000,30483000,71981000,30483000,8,39,35.4,50.0,0.1,0.052165,1


In [13]:
# (rows, columns) - the last column is later used as the label, the rest as model inputs
bz_ipo_df.shape

(249, 136)

In [14]:
# NOTE(review): appears redundant - the conversion loop a few cells below casts
# every column, including this one, to str and then to float.
bz_ipo_df['Industry'] = bz_ipo_df['Industry'].astype(str)

In [15]:
# NOTE(review): appears redundant - the conversion loop in the next cell casts
# every column, including this one, to str and then to float.
bz_ipo_df['10% Returns?'] = bz_ipo_df['10% Returns?'].astype(str)

In [16]:
# Cast every column to float, going through str so that any ',' characters
# (presumably thousands separators) can be stripped first.
for column in bz_ipo_df.columns:
    as_text = bz_ipo_df[column].astype(str)
    bz_ipo_df[column] = as_text.str.replace(',','').astype(float)

In [17]:
# Confirm that every column is now float64
bz_ipo_df.dtypes

TotalAssets                                            float64
CurrentAssets                                          float64
CashCashEquivalentsAndShortTermInvestments             float64
CashAndCashEquivalents                                 float64
Receivables                                            float64
PrepaidAssets                                          float64
TotalNonCurrentAssets                                  float64
OtherNonCurrentAssets                                  float64
TotalLiabilitiesNetMinorityInterest                    float64
CurrentLiabilities                                     float64
PayablesAndAccruedExpenses                             float64
Payables                                               float64
AccountsPayable                                        float64
CurrentAccruedExpenses                                 float64
OtherCurrentLiabilities                                float64
TotalNonCurrentLiabilitiesNetMinorityInterest          

In [18]:
# Features = every column except the last; target = the last column ('10% Returns?').
# NOTE(review): X still contains 'Return', the first-day return that
# '10% Returns?' is thresholded from, so the label can be read straight off the
# inputs (target leakage) and the reported scores are optimistic. Dropping it
# changes the feature count, so the model's input size of 135 must be updated
# in the same change.
X = bz_ipo_df.iloc[:, 0:-1]
y = bz_ipo_df.iloc[:, -1]

In [19]:
# Hold out 33% of the rows for testing; the fixed random_state makes the split repeatable.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=69)

In [20]:
# Fit the scaler on the training rows only, then apply that same transform to
# the test rows, so no test-set statistics leak into training.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

In [21]:
# Training hyper-parameters
EPOCHS = 50  # full passes over the training set
BATCH_SIZE = 64  # samples per training batch
LEARNING_RATE = 0.001  # step size passed to the Adam optimizer

In [22]:
## train data
class TrainData(Dataset):
    """Map-style dataset that pairs each feature row with its label."""

    def __init__(self, X_data, y_data):
        # Stored as given; indexing and length are delegated to these containers.
        self.X_data, self.y_data = X_data, y_data

    def __len__(self):
        # The sample count is taken from the features alone.
        return len(self.X_data)

    def __getitem__(self, index):
        features = self.X_data[index]
        label = self.y_data[index]
        return features, label


# y_train is a pandas Series that keeps its shuffled, label-based index after the
# split, so it is not a plain 0..n-1 sequence. Convert it to a NumPy array
# explicitly instead of relying on the tensor constructor to parse the Series.
train_data = TrainData(torch.FloatTensor(X_train),
                       torch.FloatTensor(y_train.values))
## test data
class TestData(Dataset):
    """Map-style dataset over feature rows only (no labels)."""

    def __init__(self, X_data):
        # Stored as given; indexing and length are delegated to this container.
        self.X_data = X_data

    def __len__(self):
        n_samples = len(self.X_data)
        return n_samples

    def __getitem__(self, index):
        features = self.X_data[index]
        return features
    

# Test features only; the labels stay in y_test and are used when scoring the predictions.
test_data = TestData(torch.FloatTensor(X_test))

In [23]:
# Training batches are reshuffled on every pass. The test loader keeps the
# original row order and yields one sample at a time, so predictions line up
# with y_test.
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=1)

In [24]:
class BinaryClassification(nn.Module):
    """Feed-forward binary classifier: two 64-unit hidden layers and one logit out.

    Args:
        n_features: Width of each input row. Defaults to 135, the value that
            was previously hard-coded, so ``BinaryClassification()`` is
            unchanged; pass e.g. ``X_train.shape[1]`` when the feature set
            changes.
    """

    def __init__(self, n_features=135):
        super().__init__()
        self.layer_1 = nn.Linear(n_features, 64)
        self.layer_2 = nn.Linear(64, 64)
        self.layer_out = nn.Linear(64, 1)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.1)
        self.batchnorm1 = nn.BatchNorm1d(64)
        self.batchnorm2 = nn.BatchNorm1d(64)

    def forward(self, inputs):
        """Return raw logits of shape (batch, 1) for inputs of shape (batch, n_features).

        No sigmoid is applied here; callers apply it (or use a loss that does).
        """
        # Each hidden block is Linear -> ReLU -> BatchNorm.
        x = self.relu(self.layer_1(inputs))
        x = self.batchnorm1(x)
        x = self.relu(self.layer_2(x))
        x = self.batchnorm2(x)
        # Dropout is applied once, just before the output layer.
        x = self.dropout(x)
        x = self.layer_out(x)

        return x

In [25]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [26]:
# Build the network and move its parameters to the selected device.
model = BinaryClassification()
model.to(device)
print(model)
# BCEWithLogitsLoss applies the sigmoid itself, which is why the model's
# forward pass returns raw logits.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

BinaryClassification(
  (layer_1): Linear(in_features=135, out_features=64, bias=True)
  (layer_2): Linear(in_features=64, out_features=64, bias=True)
  (layer_out): Linear(in_features=64, out_features=1, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.1, inplace=False)
  (batchnorm1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (batchnorm2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)


In [27]:
def binary_acc(y_pred, y_test):
    """Return the accuracy of ``y_pred`` against ``y_test`` as a whole-number percentage.

    Args:
        y_pred: Tensor of raw logits.
        y_test: Tensor of 0/1 labels; its first dimension is the sample count.
            It is compared element-wise with the predictions, so the caller
            must pass it in the same shape as ``y_pred``.

    Returns:
        0-dim tensor holding the percentage of matches, rounded to an integer value.
    """
    # Logit -> probability -> hard 0/1 label
    predicted_labels = y_pred.sigmoid().round()

    n_correct = predicted_labels.eq(y_test).sum().float()
    return (n_correct / y_test.shape[0] * 100).round()

In [28]:
# Training mode: dropout is active and batch norm uses per-batch statistics.
model.train()
for e in range(1, EPOCHS+1):
    # Running sums of the per-batch loss and accuracy for this epoch
    epoch_loss = 0
    epoch_acc = 0
    for X_batch, y_batch in train_loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        # Clear the gradients accumulated by the previous step
        optimizer.zero_grad()
        
        y_pred = model(X_batch)
        
        # Labels arrive as shape (batch,); unsqueeze to (batch, 1) to match the logits
        loss = criterion(y_pred, y_batch.unsqueeze(1))
        acc = binary_acc(y_pred, y_batch.unsqueeze(1))
        
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        epoch_acc += acc.item()
        

    # Both figures are unweighted means over batches, so a smaller final batch
    # counts as much as a full one.
    print(f'Epoch {e+0:03}: | Loss: {epoch_loss/len(train_loader):.5f} | Acc: {epoch_acc/len(train_loader):.3f}')

Epoch 001: | Loss: 0.71084 | Acc: 54.333
Epoch 002: | Loss: 0.61764 | Acc: 67.000
Epoch 003: | Loss: 0.57693 | Acc: 72.000
Epoch 004: | Loss: 0.53527 | Acc: 77.667
Epoch 005: | Loss: 0.50171 | Acc: 78.333
Epoch 006: | Loss: 0.44679 | Acc: 84.667
Epoch 007: | Loss: 0.44266 | Acc: 86.333
Epoch 008: | Loss: 0.40481 | Acc: 87.333
Epoch 009: | Loss: 0.40070 | Acc: 89.667
Epoch 010: | Loss: 0.38183 | Acc: 89.333
Epoch 011: | Loss: 0.34523 | Acc: 91.667
Epoch 012: | Loss: 0.32050 | Acc: 93.333
Epoch 013: | Loss: 0.29966 | Acc: 91.333
Epoch 014: | Loss: 0.29166 | Acc: 94.667
Epoch 015: | Loss: 0.24882 | Acc: 94.333
Epoch 016: | Loss: 0.23909 | Acc: 93.667
Epoch 017: | Loss: 0.22285 | Acc: 94.000
Epoch 018: | Loss: 0.21373 | Acc: 95.000
Epoch 019: | Loss: 0.18645 | Acc: 95.333
Epoch 020: | Loss: 0.16027 | Acc: 97.667
Epoch 021: | Loss: 0.15035 | Acc: 99.333
Epoch 022: | Loss: 0.13281 | Acc: 97.667
Epoch 023: | Loss: 0.14292 | Acc: 98.333
Epoch 024: | Loss: 0.16407 | Acc: 97.333
Epoch 025: | Los

In [29]:
# Predict a 0/1 class for every test row.
y_pred_list = []
# Evaluation mode: dropout is disabled and batch norm uses its running statistics.
model.eval()
with torch.no_grad():
    for X_batch in test_loader:
        X_batch = X_batch.to(device)
        y_test_pred = model(X_batch)
        # Logit -> probability -> hard 0/1 label
        y_test_pred = torch.sigmoid(y_test_pred)
        y_pred_tag = torch.round(y_test_pred)
        y_pred_list.append(y_pred_tag.cpu().numpy())

# test_loader uses batch_size=1, so each entry is a single-element array;
# flatten the list to plain Python floats.
y_pred_list = [a.squeeze().tolist() for a in y_pred_list]

In [30]:
# Predicted class (0.0 or 1.0) for each test row, in test_loader order
y_pred_list

[1.0,
 1.0,
 1.0,
 0.0,
 0.0,
 1.0,
 0.0,
 0.0,
 1.0,
 1.0,
 0.0,
 0.0,
 1.0,
 0.0,
 0.0,
 0.0,
 0.0,
 1.0,
 0.0,
 1.0,
 0.0,
 0.0,
 0.0,
 1.0,
 0.0,
 1.0,
 1.0,
 0.0,
 0.0,
 0.0,
 1.0,
 1.0,
 0.0,
 0.0,
 1.0,
 0.0,
 1.0,
 0.0,
 1.0,
 1.0,
 0.0,
 0.0,
 1.0,
 0.0,
 0.0,
 0.0,
 0.0,
 1.0,
 1.0,
 0.0,
 0.0,
 1.0,
 0.0,
 0.0,
 1.0,
 0.0,
 0.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 0.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 0.0,
 1.0,
 1.0,
 1.0]

In [31]:
# Rows are the true classes (0, 1); columns are the predicted classes.
confusion_matrix(y_test, y_pred_list)

array([[30,  4],
       [11, 38]], dtype=int64)

In [32]:
# Per-class precision / recall / F1 on the held-out test rows
print(classification_report(y_test, y_pred_list))

              precision    recall  f1-score   support

         0.0       0.73      0.88      0.80        34
         1.0       0.90      0.78      0.84        49

    accuracy                           0.82        83
   macro avg       0.82      0.83      0.82        83
weighted avg       0.83      0.82      0.82        83

