In [None]:
import pandas as pd
import numpy as np

class node:
    """One region of the atlas hierarchy.

    Attributes:
        idNum: region id.
        name: full region name.
        acronym: short region label.
        depth: level of the node in the tree, as given by the caller.
        parent_id: idNum of the parent; -1 until a parent's addChild attaches this node.
        children: child nodes, in the order they were added.
    """

    def __init__(self, idNum, name, acronym, depth):
        self.idNum, self.name, self.acronym = idNum, name, acronym
        self.parent_id = -1  # sentinel meaning "no parent attached yet"
        self.depth = depth
        self.children = []

    def addChild(self, child_node):
        """Attach child_node below this node and record this node as its parent."""
        child_node.parent_id = self.idNum
        self.children.append(child_node)

    def __str__(self):
        child_ids = [child.idNum for child in self.children]
        return f'{self.idNum}: {self.name} ({self.acronym}) - {len(child_ids)} child nodes ({child_ids})'

class allenBrainTree:
    """Hierarchy of brain atlas regions built from hard-coded parent -> child tables.

    layer_0_ids lists the root ids. For k >= 1, layer_k_ids maps the id of a
    depth-(k-1) node to the ids of its depth-k children. The CSV only supplies
    each region's 'name' and 'acronym' (looked up by its 'id' column). A region
    whose id is absent from the CSV is skipped together with its whole subtree.
    Depth-9 nodes are treated as leaves: no deeper table is consulted.
    """

    def __init__(self, path):
        # path: CSV (anything pd.read_csv accepts) with 'id', 'name' and 'acronym' columns,
        # e.g. */atlas_info_KimRef_FPbasedLabel_v2.7.csv. The tree shape itself comes from
        # the layer tables below, not from the CSV.
        self.baseLayer = []
        self.layer_0_ids = [0,997]
        self.layer_1_ids = {0:[], 997:[8, 1009, 2155, 73, 1024]}
        self.layer_2_ids = {8:[567,512,343], 1009:[967,960,983,1000,991], 2155:[], 73:[81,124,129,140,145,164], 1024:[1032,1040,624]}
        self.layer_3_ids = {567:[688,623], 512:[528,519], 343:[1129,313,1065], 967:[949,840,848,832,911,901,798,933,2222,808,917,813,792,871,2275], 960:[744,752,728], 983:[776,784,896], 1000:[760,877,863,941], 991:[768,824], 81:[98,108], 124:[], 129:[2373], 140:[], 145:[2219,153], 164:[], 1032:[1063,1071], 1040:[1095,1087,1103,1112,1119,3,11,34,43,65], 624:[]}
        self.layer_4_ids = {688:[695,703], 623:[477,803], 528:[645,1073], 519:[989,91,2281], 1129:[549,1097], 313:[2107,339,323,2112,2101,348], 1065:[771,354], 949:[], 840:[1016,21,900,2317], 848:[916,336,117,125,2040], 832:[62,158], 911:[], 901:[93,229], 798:[1131,1116], 933:[2461,413,948], 2222:[], 808:[], 917:[237], 813:[], 792:[932], 871:[293], 2275:[], 744:[], 752:[326,78,1123], 728:[], 776:[956,1108,971,986], 784:[6,924,190,198,2169], 896:[1092,14,2109], 760:[102], 877:[1060], 863:[397], 941:[], 768:[908,940,1099,37,301], 824:[54,341,46,1083], 98:[], 108:[], 2373:[], 2219:[], 153:[], 1063:[], 1071:[], 1095:[], 1087:[], 1103:[], 1112:[], 1119:[], 3:[], 11:[], 34:[], 43:[], 65:[]}
        self.layer_5_ids = {695:[315,698,1089], 703:[16,583,942,131,295,319,780], 477:[485,493,275,278], 803:[1022,1031,835,342,826,809], 645:[912,976,984,928,936,944,951,957,968], 1073:[1007,1056,1025,1033,1041,1049,1064], 989:[2244,2245], 91:[2239,2240,2241,2242], 2281:[846,2247], 549:[2414,864,856], 1097:[157,141,467,290,10671], 2107:[], 339:[2476,4,580,271,874,2163,460,2108], 323:[2147,2131,2181,2154,381,749,246,795,2123,50,67,587,2152,2156,1100,616,214,2136,35,975,115,757,231,2168,66,75,58,615], 2112:[], 2101:[], 348:[374,1052,165], 771:[1132,987,2166,1117], 354:[2246,386,370,379], 1016:[], 21:[2279,665,459], 900:[], 2317:[], 916:[], 336:[], 117:[], 125:[], 2040:[], 62:[], 158:[], 93:[], 229:[705,794], 1131:[], 1116:[], 2461:[], 413:[], 948:[841,506,658,633,482], 237:[], 932:[514,697], 293:[], 326:[812,850,866], 78:[2167], 1123:[553,490], 956:[579], 1108:[], 971:[], 986:[], 6:[], 924:[], 190:[], 198:[], 2169:[], 1092:[], 14:[], 2109:[], 102:[], 1060:[], 397:[], 908:[], 940:[], 1099:[466,530,603,737,618], 37:[], 301:[2394], 54:[], 341:[], 46:[753,690,681,673], 1083:[802,595,611]}
        self.layer_6_ids = {315:[184,500,453,44,1057,677,247,669,31,972,714,95,254,22,541,922,895], 698:[507,151,159,589,814,961,619,631,788,566], 1089:[1080,822], 16:[], 583:[2002,2003], 942:[952,966,2085,2426], 131:[2068,2052,2060], 295:[303,311,451], 319:[327,334], 780:[2086,2087,2088], 485:[672], 493:[56,998,754], 275:[242,310,333,2009], 278:[23,292,536,1105,403], 1022:[], 1031:[], 835:[298,2028], 342:[2371,2027,2013], 826:[904,581], 809:[2308,351,287], 912:[10707,10705], 976:[10710,10708], 984:[10713,10711], 928:[992,1001,1091], 936:[10725,10723], 944:[10728,10726], 951:[10731,10729], 957:[10734,10732], 968:[10737,10735], 1007:[10674,10672], 1056:[10677,10675], 1025:[10683,10681], 1033:[10686,10684], 1041:[10689,10687], 1049:[10692,10690], 1064:[10678,10680], 2244:[], 2245:[], 2239:[], 2240:[], 2241:[], 2242:[], 846:[], 2247:[], 2414:[], 864:[2115,2116,2117,2120,637,406,1044,2158,1008], 856:[138,239,444,571,51,262,1014,958], 157:[390,332,38,30,223], 141:[272,668,830,452,2307,523,763,914,1109,133,347,286,338,2011,689], 467:[88,331,515,980,1004,63,693,2474], 290:[194,226,2379,364,173,470,2063,614,2079,2111,797], 10671:[], 2476:[302,294], 4:[811,820,828,2457,2475], 580:[], 271:[], 874:[], 2163:[], 460:[], 2108:[], 2147:[], 2131:[], 2181:[], 2154:[], 381:[], 749:[2138,2144,2124,2427], 246:[2153], 795:[2141,2142,2143,2148,2162,2114], 2123:[], 50:[], 67:[2129], 587:[], 2152:[], 2156:[], 1100:[215,531,2415,628,634,706,1061,2122], 616:[2192], 214:[2137,2139], 2136:[], 35:[2157], 975:[2130], 115:[2180,2179], 757:[], 231:[], 2168:[], 66:[], 75:[], 58:[], 615:[], 374:[2134,2135], 1052:[2183], 165:[12,100,197,591,872], 1132:[612,7,867,398], 987:[2191,280,880,898,931,2206,2220,318,462,534,574,621,2193], 2166:[], 1117:[2197,679,147,2284,146,2185,2186,238], 2246:[], 386:[207,2274,607,720,903,642,2249,651,429,437,445,2270], 
370:[2477,2478,2266,2198,2199,2200,653,2228,661,576,640,2472,839,1048,83,136,106,203,235,395,852,2211,859,938,2252,2469,2269,2468,154,1069,701,765,2267,773,2272,781], 379:[206,230,222], 2279:[], 665:[], 459:[], 705:[], 794:[], 841:[], 506:[], 658:[], 633:[], 482:[], 514:[380,388], 697:[], 812:[85], 850:[], 866:[], 2167:[], 553:[], 490:[404], 579:[], 466:[], 530:[], 603:[], 737:[436], 618:[443,449], 2394:[], 753:[], 690:[], 681:[], 673:[], 802:[], 595:[], 611:[]}
        self.layer_7_ids = {184:[68,667,2325], 500:[985,993], 453:[322,378], 44:[707,747,556,827,1054,1081], 1057:[36,180,148,187,638,662], 677:[897,1106,1010,1058,857,849], 247:[1011,1002,1018], 669:[2170,2096,385], 31:[2345,39,48], 972:[171,195,304,363,84,132], 714:[723,731,738], 95:[2331,2332,2333,104,111,119], 254:[879,886], 22:[2056,2057,2082,2286], 541:[97,1127,234,289,729,786], 922:[335,368,540,692,888], 895:[836,427,988,977,1045], 507:[212,220,228,236,244], 151:[188,196,204,2021], 159:[167,175,160,183,191,199], 589:[597,1034,1042,1050,1059,605,1067,1075,1082], 814:[496,535,360,646,267,2019], 961:[276,284,291,2024], 619:[260,268,1139], 631:[639,647], 788:[408,416,424], 566:[1140,1141,1142], 1080:[375,726,982,19], 822:[909,1084,843,1037,502], 2002:[], 2003:[], 952:[], 966:[], 2085:[], 2426:[], 2068:[], 2052:[], 2060:[], 303:[], 311:[], 451:[], 327:[], 334:[], 2086:[], 2087:[], 2088:[], 672:[2376,2491,2492,2496,2495,2001,2050], 56:[2074,2006,2007], 998:[2012,2372], 754:[481,489,458,465,473,2018], 242:[250,258,266,2008], 310:[], 333:[], 2009:[], 23:[], 292:[], 536:[544,551,559], 1105:[2375], 403:[411,418,426,435], 298:[], 2028:[], 2371:[], 2027:[], 2013:[], 904:[564,2280], 581:[], 2308:[], 351:[359,2359,367,2360,2049], 287:[], 10707:[], 10705:[], 10710:[], 10708:[], 10713:[], 10711:[], 992:[], 1001:[], 1091:[10722,10720], 10725:[], 10723:[], 10728:[], 10726:[], 10731:[], 10729:[], 10734:[], 10732:[], 10737:[], 10735:[], 10674:[], 10672:[], 10677:[], 10675:[], 10683:[], 10681:[], 10686:[], 10684:[], 10689:[], 10687:[], 10692:[], 10690:[], 10678:[], 10680:[], 2115:[], 2116:[], 2117:[], 2120:[], 637:[629,2316,685,709], 406:[422], 1044:[], 2158:[], 1008:[475,170], 138:[218,2094,1020,1029,325], 239:[255,1096,64,1120,1113,2282], 444:[59,2091,362,366], 571:[149,15,2038,2032,181], 51:[189,599,907,575,930], 262:[], 1014:[27,178,321], 958:[483,186,953], 390:[2051,2034], 332:[432], 38:[71,94], 30:[], 223:[2061,2062,2104,2095,2103], 272:[], 668:[], 830:[2076,2077,2078], 
452:[], 2307:[], 523:[2016,2023,2306], 763:[], 914:[], 1109:[], 133:[], 347:[], 286:[2030,2031], 338:[], 2011:[], 689:[], 88:[700,708,724,2029], 331:[210,491,525,557], 515:[748,756], 980:[], 1004:[], 63:[439], 693:[2044,769,777,785], 2474:[946], 194:[2309,2065,2033,2105,2081,2048,2473], 226:[], 2379:[], 364:[], 173:[2283,2042], 470:[], 2063:[], 614:[2080], 2079:[], 2111:[], 797:[2125,2043,2053,2054,796,2110], 302:[851,842,834], 294:[26,42,17,10], 811:[], 820:[], 828:[2188,2189,2190], 2457:[], 2475:[], 2138:[], 2144:[], 2124:[], 2427:[], 2153:[], 2141:[], 2142:[], 2143:[], 2148:[], 2162:[], 2114:[], 2129:[], 215:[2099,2100], 531:[], 2415:[], 628:[], 634:[2121], 706:[], 1061:[], 2122:[], 2192:[], 2137:[], 2139:[], 2157:[], 2130:[], 2180:[], 2179:[], 2134:[], 2135:[], 2183:[], 12:[], 100:[2140,2145,2146,2149,2151,2150,2160], 197:[], 591:[], 872:[2173,2174,2175,2176,2177,2178], 612:[82,99,2172,2184], 7:[2203,2202], 867:[123,881,890], 398:[122,2456,105,114,2187], 2191:[], 280:[2459,2460], 880:[2210,2201,2230], 898:[2215,2216,2231,2221], 931:[], 2206:[], 2220:[], 318:[], 462:[], 534:[], 574:[2165], 621:[2204,2453,2194,2454,2455], 2193:[], 2197:[], 679:[137,130], 147:[], 2284:[162,2195], 146:[2205,2217,2182], 2185:[], 2186:[], 238:[], 207:[], 2274:[], 607:[112,560,96,101], 720:[711,1039], 903:[], 642:[], 2249:[], 651:[659,666,674,2250,682,691,2263,2273,2271,2264,2265,2470], 429:[], 437:[], 445:[45], 2270:[], 2477:[], 2478:[], 2266:[], 2198:[], 2199:[], 2200:[], 653:[2229], 2228:[], 661:[2226,2462,2463,2464,2465,2466,2467], 576:[], 640:[], 2472:[135,2277,2471], 839:[], 1048:[2223,2224], 83:[2253,2254,2255,2256,2257,2258,2259,2260,2261,2262], 136:[2225], 106:[], 203:[], 235:[963], 395:[1098,1107], 852:[2214], 2211:[], 859:[], 938:[2248,970,978], 2252:[], 2469:[], 2269:[], 2468:[2251], 154:[161,177,169], 1069:[], 701:[217,209,202,225,2268,2218], 765:[], 2267:[], 773:[], 2272:[], 781:[], 206:[2212], 230:[], 222:[], 380:[], 388:[], 85:[], 404:[], 436:[], 443:[], 449:[]}
        self.layer_8_ids = {68:[], 667:[], 2325:[], 985:[320,943,648,844,882,2017], 993:[656,962,767,1021,1085], 322:[2338,793,329,337,345,2000,2352,2041,369,361], 378:[873,806,1035,1090,862,893], 707:[], 747:[], 556:[], 827:[], 1054:[], 1081:[], 36:[], 180:[], 148:[], 187:[], 638:[], 662:[], 897:[], 1106:[], 1010:[], 1058:[], 857:[], 849:[], 1011:[527,600,678,252,156,243], 1002:[735,251,816,847,954,1005], 1018:[959,755,990,1023,520,598], 2170:[402,409,425], 2096:[2097,394,533], 385:[593,821,721,778,33,305,2098,2133], 2345:[2346,2347,2348,2349,2350], 39:[935,211,1015,919,927], 48:[588,296,772,810,819], 171:[], 195:[], 304:[], 363:[], 84:[], 132:[], 723:[448,412,630,440,488], 731:[484,524,582,620,910], 738:[2320,2321,2322,2323,2324], 2331:[], 2332:[], 2333:[], 104:[996,328,1101,783,831,2020], 111:[120,163,344,314,355], 119:[704,694,800,675,699], 879:[442,434,545,610,274,330], 886:[2132,2066,2067], 2056:[2381,2382,2383,2384,2385,2386], 2057:[2387,2388,2389,2390,2391,2392], 2082:[2402,2403,2404,2405,2406,2407], 2286:[2396,2397,2398,2399,2400,2401], 97:[], 1127:[], 234:[], 289:[], 729:[], 786:[], 335:[], 368:[], 540:[], 692:[], 888:[], 836:[], 427:[], 988:[], 977:[], 1045:[], 212:[], 220:[], 228:[], 236:[], 244:[], 188:[], 196:[], 204:[], 2021:[], 167:[], 175:[], 160:[], 183:[], 191:[], 199:[2318,2319], 597:[], 1034:[], 1042:[], 1050:[], 1059:[], 605:[], 1067:[], 1075:[], 1082:[], 496:[], 535:[], 360:[], 646:[], 267:[], 2019:[], 276:[], 284:[], 291:[], 2024:[2303,2304,2305], 260:[], 268:[], 1139:[], 639:[192,200,208], 647:[655,663], 408:[], 416:[], 424:[], 1140:[], 1141:[], 1142:[], 375:[382,423,463], 726:[10703,10704,632], 982:[], 19:[], 909:[918,926,934,2159,2161], 1084:[10699,10700,10701], 843:[10693,10694,10695], 1037:[10696,10697,10698], 502:[509,518,2440], 2376:[], 2491:[2294,2295,2296,2497], 2492:[2498,2500,2499,2501], 2496:[2493,2494,2490], 2495:[], 2001:[], 2050:[], 2074:[], 2006:[], 2007:[], 2012:[], 2372:[], 481:[], 489:[], 458:[], 465:[], 473:[], 2018:[], 
250:[], 258:[], 266:[], 2008:[], 544:[], 551:[], 559:[], 2375:[], 411:[], 418:[], 426:[], 435:[], 564:[2005], 2280:[596,2293,2004], 359:[537,498,513,546,554,529], 2359:[2010,2025], 367:[578,585], 2360:[2015,2014,2022], 2049:[], 10722:[], 10720:[], 629:[], 2316:[], 685:[], 709:[718,733,2361], 422:[], 475:[1072,1079,1088,2126], 170:[], 218:[2069,2070,2071,2127], 2094:[], 1020:[2055,2128], 1029:[], 325:[], 255:[2035,2036], 1096:[1104], 64:[], 1120:[], 1113:[], 2282:[155,2046], 59:[], 2091:[], 362:[617,626,636], 366:[2377,2378], 149:[2089,2090], 15:[], 2038:[2039], 2032:[], 181:[2092,2093,2037], 189:[], 599:[], 907:[2064], 575:[], 930:[], 27:[], 178:[2072,2073], 321:[], 483:[], 186:[2058,2059], 953:[], 2051:[], 2034:[], 432:[], 71:[79,652,660], 94:[55,87], 2061:[], 2062:[], 2104:[], 2095:[], 2103:[], 2076:[], 2077:[], 2078:[], 2016:[], 2023:[], 2306:[], 2030:[], 2031:[], 700:[], 708:[], 724:[2047], 2029:[], 210:[], 491:[2118,2119,732], 525:[1110,1118], 557:[1126,1], 748:[], 756:[], 439:[], 2044:[], 769:[], 777:[], 785:[], 946:[2364,2365,2075,2102], 2309:[], 2065:[], 2033:[], 2105:[], 2081:[], 2048:[], 2473:[], 2283:[], 2042:[], 2080:[], 2125:[], 2043:[], 2053:[], 2054:[], 796:[], 2110:[], 851:[], 842:[], 834:[], 26:[], 42:[], 17:[], 10:[], 2188:[], 2189:[], 2190:[], 2099:[], 2100:[], 2121:[], 2140:[], 2145:[], 2146:[], 2149:[], 2151:[], 2150:[], 2160:[], 2173:[], 2174:[], 2175:[], 2176:[], 2177:[], 2178:[], 82:[], 99:[], 2172:[], 2184:[], 2203:[], 2202:[], 123:[], 881:[860,2458,868,875,883,891,2213], 890:[899], 122:[2207,2208,2209,2227], 2456:[], 105:[], 114:[], 2187:[], 2459:[], 2460:[], 2210:[], 2201:[], 2230:[], 2215:[], 2216:[], 2231:[], 2221:[], 2165:[], 2204:[], 2453:[], 2194:[], 2454:[], 2455:[], 137:[], 130:[], 162:[2366,2367], 2195:[], 2205:[], 2217:[], 2182:[], 112:[], 560:[], 96:[2234,2235,2236], 101:[2238,2237,2243], 711:[2285,2278], 1039:[], 659:[], 666:[], 674:[], 2250:[], 682:[], 691:[], 2263:[], 2273:[], 2271:[], 2264:[], 2265:[], 2470:[], 45:[], 
2229:[], 2226:[], 2462:[], 2463:[], 2464:[], 2465:[], 2466:[], 2467:[], 135:[], 2277:[], 2471:[], 2248:[], 970:[], 978:[], 2223:[], 2224:[], 2253:[], 2254:[], 2255:[], 2256:[], 2257:[], 2258:[], 2259:[], 2260:[], 2261:[], 2262:[], 2225:[], 963:[], 1098:[], 1107:[], 2214:[], 2251:[], 161:[2276], 177:[], 169:[], 217:[], 209:[], 202:[2232,2233], 225:[], 2268:[], 2218:[], 2212:[]}
        self.layer_9_ids = {320:[], 943:[], 648:[], 844:[], 882:[], 2017:[2334,2335,2336,2337,2351], 656:[], 962:[], 767:[], 1021:[], 1085:[], 2338:[2339,2340,2341,2342,2343,2344], 793:[346,865,921,686,719], 329:[981,201,1047,1070,1038,1062], 337:[1030,113,1094,1128,478,510], 345:[878,657,950,974,1102,2], 2000:[2287,2288,2289,2290,2291,2292], 2352:[2353,2354,2355,2356,2357,2358], 2041:[2310,2311,2312,2313,2314,2315], 369:[450,854,577,625,945,1026], 361:[1006,670,1086,1111,9,461], 873:[], 806:[], 1035:[], 1090:[], 862:[], 893:[], 527:[], 600:[], 678:[], 252:[], 156:[], 243:[], 735:[], 251:[], 816:[], 847:[], 954:[], 1005:[], 959:[], 755:[], 990:[], 1023:[], 520:[], 598:[], 402:[1074,905,1114,233,601,649], 409:[421,973,573,613,74,121], 425:[750,269,869,902,377,393], 2097:[2393,2409,2410,2411,2412,2413], 394:[281,1066,401,433,1046,441], 533:[805,41,501,565,257,469], 593:[], 821:[], 721:[], 778:[], 33:[], 305:[], 2098:[2428,2429,2430,2431,2432,2433], 2133:[2434,2435,2436,2437,2438,2439], 2346:[], 2347:[], 2348:[], 2349:[], 2350:[], 935:[], 211:[], 1015:[], 919:[], 927:[], 588:[], 296:[], 772:[], 810:[], 819:[], 448:[], 412:[], 630:[], 440:[], 488:[], 484:[], 524:[], 582:[], 620:[], 910:[], 2320:[], 2321:[], 2322:[], 2323:[], 2324:[], 996:[], 328:[], 1101:[], 783:[], 831:[], 2020:[2326,2327,2328,2329,2330], 120:[], 163:[], 344:[], 314:[], 355:[], 704:[], 694:[], 800:[], 675:[], 699:[], 442:[], 434:[], 545:[], 610:[], 274:[], 330:[], 2132:[542,606,430,687,590,622], 2066:[2416,2417,2418,2419,2420], 2067:[2421,2422,2423,2424,2425], 2381:[], 2382:[], 2383:[], 2384:[], 2385:[], 2386:[], 2387:[], 2388:[], 2389:[], 2390:[], 2391:[], 2392:[], 2402:[], 2403:[], 2404:[], 2405:[], 2406:[], 2407:[], 2396:[], 2397:[], 2398:[], 2399:[], 2400:[], 2401:[], 2318:[], 2319:[], 2303:[], 2304:[], 2305:[], 192:[], 200:[], 208:[], 655:[216,224,232], 663:[240,248,256], 382:[391,399,407,415], 423:[431,438,446,454], 463:[471,479,486,495,504], 10703:[], 10704:[], 632:[], 
918:[1121,20,999,715,764,52,92,312,139,387,28,60], 926:[526,543,468,508,2408,664,712,727,550,743], 934:[259,324,371,419,1133], 2159:[2164,2444,2445,2446,2447], 2161:[2448,2449,2450,2451,2452], 10699:[], 10700:[], 10701:[], 10693:[], 10694:[], 10695:[], 10696:[], 10697:[], 10698:[], 509:[829,845,837], 518:[853,870,861], 2440:[2441,2442,2443], 2294:[], 2295:[], 2296:[], 2497:[2395,2297], 2498:[2299,2298,2374,2380], 2500:[2302,2480,2483], 2499:[2300,2301], 2501:[2479,2482,2481,2370], 2493:[2484,2485,2486], 2494:[2487,2489,2488], 2490:[], 2005:[], 596:[], 2293:[], 2004:[], 537:[], 498:[], 513:[], 546:[], 554:[], 529:[], 2010:[], 2025:[], 578:[], 585:[], 2015:[], 2014:[], 2022:[], 718:[], 733:[2362,2363], 2361:[], 1072:[], 1079:[], 1088:[], 2126:[], 2069:[], 2070:[], 2071:[], 2127:[], 2055:[], 2128:[], 2035:[], 2036:[], 1104:[], 155:[], 2046:[], 617:[], 626:[], 636:[], 2377:[], 2378:[], 2089:[], 2090:[], 2039:[], 2092:[], 2093:[], 2037:[], 2064:[], 2072:[], 2073:[], 2058:[], 2059:[], 79:[], 652:[], 660:[], 55:[2026], 87:[], 2047:[], 2118:[], 2119:[], 732:[], 1110:[], 1118:[], 1126:[], 1:[], 2364:[], 2365:[], 2075:[], 2102:[], 860:[], 2458:[], 868:[], 875:[], 883:[], 891:[], 2213:[], 899:[], 2207:[], 2208:[], 2209:[], 2227:[], 2366:[], 2367:[], 2234:[], 2235:[], 2236:[], 2238:[], 2237:[], 2243:[], 2285:[], 2278:[], 2276:[], 2232:[], 2233:[2368,2369]}

        self.data = pd.read_csv(path)

        # children_maps[d] maps the id of a depth-d node to the ids of its depth-(d+1) children.
        children_maps = [self.layer_1_ids, self.layer_2_ids, self.layer_3_ids,
                         self.layer_4_ids, self.layer_5_ids, self.layer_6_ids,
                         self.layer_7_ids, self.layer_8_ids, self.layer_9_ids]
        for root_id in self.layer_0_ids:
            root = self._buildNode(root_id, 0, children_maps)
            if root is not None:
                self.addNode(root)

    # Helper for the constructor: create the node for idNum at `depth` and, recursively,
    # all of its descendants. Returns None when idNum has no row in the CSV, which drops
    # that region together with its subtree. Raises KeyError if a table is missing an id.
    def _buildNode(self, idNum, depth, children_maps):
        row = self.data.loc[self.data['id'] == idNum]
        if row.empty:
            return None
        new_node = node(idNum, row['name'].item(), row['acronym'].item(), depth)
        if depth >= len(children_maps):
            return new_node  # deepest layer covered by the tables
        seen = set()
        for child_id in children_maps[depth][idNum]:
            # A region can be a child only once; ignore repeated ids in a table entry.
            if child_id in seen:
                continue
            seen.add(child_id)
            child = self._buildNode(child_id, depth + 1, children_maps)
            if child is not None:
                new_node.addChild(child)
        return new_node

    # Helper function for constructor which adds nodes to the base layer of the tree
    def addNode(self,node):
        self.baseLayer.append(node)

    # =======================================
    # ===== Functions to print out tree =====
    # =======================================
    def printTreeHelp(self, node, depth):
        # Print `node` labelled with `depth` and indented two spaces per level,
        # then its children recursively (pre-order).
        print(f'({depth}) {"  " * depth}{node.idNum}: {node.name} ({node.acronym})')
        for child_node in node.children:
            self.printTreeHelp(child_node, depth + 1)

    def printTree(self):
        # Print every tree, one region per line, starting from each root at depth 0.
        for child in self.baseLayer:
            self.printTreeHelp(child, 0)

    # ===================================================================================
    # ===== Returns a list of node targetID and all its descendants (pre-order)     =====
    # ===================================================================================
    def allDescHelp(self, out, node, targetID, isDesc, verbose):
        # Append `node` to `out` if isDesc (node is the target or below it), then recurse;
        # a child is inside the target subtree if its parent is or if it is the target itself.
        if isDesc:
            if verbose:
                out.append(f'({node.depth}) {node.idNum}: {node.name} ({node.acronym})')
            else:
                out.append(node.idNum)
        for child_node in node.children:
            self.allDescHelp(out, child_node, targetID, isDesc or child_node.idNum == targetID, verbose)

    def allDesc(self, targetID, verbose = False):
        # Return node targetID followed by all of its descendants, in pre-order.
        # If verbose: a list of strings '(depth) id: name (acronym)';
        # else:       a list of integer IDs.
        # Returns [] when targetID is not in the tree.
        out = []
        for root in self.baseLayer:
            self.allDescHelp(out, root, targetID, root.idNum == targetID, verbose)
        return out

    # =================================================================
    # ===== Extract a subtree from allenBrainTree with root idNum =====
    # =================================================================
    def subtreeHelp(self, idNum, node):
        # Depth-first search below (and including) `node`; returns the first node whose
        # idNum matches, or None if there is none.
        if node.idNum == idNum:
            return node
        for child_node in node.children:
            found = self.subtreeHelp(idNum, child_node)
            if found is not None:
                return found
        return None

    def subtree(self, idNum):
        # Return the first node (searching the roots in order) whose idNum matches;
        # its descendants form the subtree. Returns None when idNum is not in the tree.
        for root in self.baseLayer:
            found = self.subtreeHelp(idNum, root)
            if found is not None:
                return found
        return None

In [None]:
# Build the atlas tree, then show region 672 and all of its descendants as readable strings
label_to_id_csv_path = '/nafs/dtward/dong/upenn_atlas/atlas_info_KimRef_FPbasedLabel_v2.7.csv'
abt = allenBrainTree(path=label_to_id_csv_path)
abt.allDesc(targetID=672, verbose=True)

In [None]:
# Print every region of the atlas tree, one per line, indented by depth
allenBrainTree.printTree(abt)

### Important IDs
Each entry is a descendant of the entry above it:
- 997: root (root)
  - 477: Striatum (STR)
    - 485: Striatum dorsal region (STRd)
      - 672: Caudate Putamen (CP)
        - 2492: Caudate Putamen, intermediate (CPi)

In [None]:
import pandas as pd

# Inputs to function
brain = 'TME09-1'   # The brain for post-processing
target_region = 477 # The ID of the 'highest' region of interest (all descendants of this node are included)
target_depth = 9    # The desired depth of the new subtree

# Generate list of regions to keep
label_to_id_csv_path = '/nafs/dtward/dong/upenn_atlas/atlas_info_KimRef_FPbasedLabel_v2.7.csv'
abt = allenBrainTree(label_to_id_csv_path)
regionsToKeep = abt.allDesc(target_region, verbose = False) # Integer IDs of target_region and all of its descendants
subtreeToKeep = abt.subtree(target_region) # A subtree with the target_region node as the root
if subtreeToKeep is None:
    raise ValueError(f'target_region {target_region} was not found in the atlas tree')

# Index every node of the kept subtree by id, so the collapse step below can look
# nodes up directly instead of searching the whole tree once per column per depth.
nodes_by_id = {}
pending = [subtreeToKeep]
while pending:
    region_node = pending.pop()
    nodes_by_id[region_node.idNum] = region_node
    pending.extend(region_node.children)

# Load soma_label_prob_v08.csv
label_csv_path = f'/home/abenneck/dragonfly_work/dragonfly_outputs/{brain}/dragonfly_joint_outputs/soma_label_prob_v08.csv'
data = pd.read_csv(label_csv_path)
data = data.drop(labels=' ', axis=1)

# Add an all-zero column for every atlas id that has no column yet.
# Region columns are named ' <id>' (leading space); data.columns[0] is not a region column.
all_ids = list(pd.read_csv(label_to_id_csv_path)['id'])
data_ids = {int(x) for x in data.columns[1:]}
missing_cols = [f' {label}' for label in all_ids if label not in data_ids]
if missing_cols:
    zeros = pd.DataFrame(np.zeros((len(data), len(missing_cols))), index=data.index, columns=missing_cols)
    data = pd.concat([data, zeros], axis=1)

# Remove region columns not in target_region (the first column is always kept)
keep_ids = set(regionsToKeep)
keep_mask = [True] + [int(col) in keep_ids for col in data.columns[1:]]
data = data.loc[:, keep_mask].copy()

# Collapse the tree bottom-up: at each depth from lowest_depth down to target_depth,
# P(parent) += P(child_1) + ... + P(child_n), then drop the child columns.
# The bounded range (instead of `while True`) also terminates when target_depth < 0.
lowest_depth = 8 # (DO NOT CHANGE) A constant which represents the lowest depth in the tree with nodes that have children
for depth in range(lowest_depth, target_depth - 1, -1):
    for col in list(data.columns[1:]):
        region_node = nodes_by_id[int(col)]
        if region_node.depth != depth:
            continue
        parent_col = f' {region_node.idNum}'
        # dict.fromkeys dedupes child ids so a repeated child is neither added twice nor dropped twice
        for child_id in dict.fromkeys(child.idNum for child in region_node.children):
            child_col = f' {child_id}'
            data[parent_col] = data[parent_col] + data[child_col]
            data = data.drop(labels=child_col, axis=1)

label_cols = list(data.columns[1:])

# Predict which region the neuron lies in as the region column with the largest probability.
# Restricting to label_cols keeps the first (non-region) column out of the argmax.
data['pred_region'] = data[label_cols].idxmax(axis=1)

# Total probability mass inside target_region for each neuron (region columns only)
data['unweighted_sum'] = data[label_cols].sum(axis=1)

# Redefine each probability P(neuron is in sub-region) as the conditional probability
# P(neuron is in sub-region | neuron is in target_region) = P(sub-region) / P(target_region).
# Rows whose mass is 0 are divided by 0.1 instead, to avoid a division by zero;
# the stored unweighted_sum keeps its original value.
denominator = data['unweighted_sum'].where(data['unweighted_sum'] != 0, 0.1)
data[label_cols] = data[label_cols].div(denominator, axis=0)

# Append P(neuron is NOT in target_region)
data['out'] = 1 - data['unweighted_sum']

# Save new df as csv
out_path = f'/home/abenneck/dragonfly_work/dragonfly_outputs/{brain}/dragonfly_joint_outputs/soma_label_cond_prob_D{target_depth}.csv'
data.to_csv(out_path, index=False)

In [None]:
# Count the regionsToKeep ids that still have a ' <id>' column in `data`,
# printing each id that does not; the count is displayed as the cell result.
existing_cols = set(data.columns[1:])
idx = 0
for r in regionsToKeep:
    if f' {r}' not in existing_cols:
        print(r)
        continue
    idx += 1
idx

In [None]:
# Add an all-zero ' <id>' column for every atlas id that `data` has no column for.
# Only headers that parse as integers count as region columns, so this also works after
# non-region columns such as 'pred_region', 'unweighted_sum' and 'out' have been appended.
all_ids = list(pd.read_csv(label_to_id_csv_path)['id'])
data_ids = set()
for col in data.columns[1:]:
    try:
        data_ids.add(int(col))
    except ValueError:
        continue  # not a region column
# dict.fromkeys keeps the CSV order and adds each missing id only once
missing_cols = list(dict.fromkeys(f' {label}' for label in all_ids if label not in data_ids))
if missing_cols:
    # One concat instead of one insert per column (avoids a fragmented DataFrame)
    zeros = pd.DataFrame(np.zeros((len(data), len(missing_cols))), index=data.index, columns=missing_cols)
    data = pd.concat([data, zeros], axis=1)

In [None]:
# Number of rows in /dragonfly_work/temp_{target_depth}.csv
# NOTE(review): this absolute '/dragonfly_work/...' path differs from the
# '/home/abenneck/dragonfly_work/...' paths used elsewhere in this notebook — confirm it is intended.
pd.read_csv(f'/dragonfly_work/temp_{target_depth}.csv').shape[0]

In [None]:
# Show region 2497: id, name, acronym and the ids of its child nodes
print(abt.subtree(idNum=2497))

In [None]:
# NOTE(review): this cell held only the bare path below. Python parses it as
# `dragonfly_work / temp_9_.csv`, which fails with NameError because `dragonfly_work`
# is never defined. It is kept as a comment so the path is not lost.
# dragonfly_work/temp_9_.csv

In [None]:
# Preview the first five rows of the processed table
data.head(5)

In [None]:
# Show at most 10 columns when displaying DataFrames
# (pass None instead of 10 to show every column).
pd.set_option('display.max_columns', 10)

# Rows whose predicted region is `pred_region`; region column names carry a leading space.
pred_region = 23
data[data['pred_region'] == f' {pred_region}']