@@ -1352,64 +1352,6 @@ <h1>Source code for torch.cuda</h1><div class="highlight"><pre>
13521352
13531353
13541354
1355- < span class ="k "> def</ span > < span class ="nf "> _get_device</ span > < span class ="p "> (</ span > < span class ="n "> device</ span > < span class ="p "> :</ span > < span class ="n "> Union</ span > < span class ="p "> [</ span > < span class ="nb "> int</ span > < span class ="p "> ,</ span > < span class ="nb "> str</ span > < span class ="p "> ,</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> device</ span > < span class ="p "> ])</ span > < span class ="o "> -></ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> device</ span > < span class ="p "> :</ span >
1356- < span class ="sa "> r</ span > < span class ="sd "> """Return the torch.device type object from the passed in device.</ span >
1357-
1358- < span class ="sd "> Args:</ span >
1359- < span class ="sd "> device (torch.device or int): selected device.</ span >
1360- < span class ="sd "> """</ span >
1361- < span class ="k "> if</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> device</ span > < span class ="p "> ,</ span > < span class ="nb "> str</ span > < span class ="p "> ):</ span >
1362- < span class ="n "> device</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> device</ span > < span class ="p "> (</ span > < span class ="n "> device</ span > < span class ="p "> )</ span >
1363- < span class ="k "> elif</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> device</ span > < span class ="p "> ,</ span > < span class ="nb "> int</ span > < span class ="p "> ):</ span >
1364- < span class ="n "> device</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> device</ span > < span class ="p "> (</ span > < span class ="s1 "> 'cuda'</ span > < span class ="p "> ,</ span > < span class ="n "> device</ span > < span class ="p "> )</ span >
1365- < span class ="k "> return</ span > < span class ="n "> device</ span >
1366-
1367-
1368- < span class ="k "> def</ span > < span class ="nf "> _get_generator</ span > < span class ="p "> (</ span > < span class ="n "> device</ span > < span class ="p "> :</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> device</ span > < span class ="p "> )</ span > < span class ="o "> -></ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> _C</ span > < span class ="o "> .</ span > < span class ="n "> Generator</ span > < span class ="p "> :</ span >
1369- < span class ="sa "> r</ span > < span class ="sd "> """Return the CUDA Generator object for the given device.</ span >
1370-
1371- < span class ="sd "> Args:</ span >
1372- < span class ="sd "> device (torch.device): selected device.</ span >
1373- < span class ="sd "> """</ span >
1374-
1375- < span class ="n "> idx</ span > < span class ="o "> =</ span > < span class ="n "> device</ span > < span class ="o "> .</ span > < span class ="n "> index</ span >
1376- < span class ="k "> if</ span > < span class ="n "> idx</ span > < span class ="ow "> is</ span > < span class ="kc "> None</ span > < span class ="p "> :</ span >
1377- < span class ="n "> idx</ span > < span class ="o "> =</ span > < span class ="n "> current_device</ span > < span class ="p "> ()</ span >
1378- < span class ="k "> return</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> cuda</ span > < span class ="o "> .</ span > < span class ="n "> default_generators</ span > < span class ="p "> [</ span > < span class ="n "> idx</ span > < span class ="p "> ]</ span >
1379-
1380-
1381- < span class ="k "> def</ span > < span class ="nf "> _set_rng_state_offset</ span > < span class ="p "> (</ span > < span class ="n "> offset</ span > < span class ="p "> :</ span > < span class ="nb "> int</ span > < span class ="p "> ,</ span > < span class ="n "> device</ span > < span class ="p "> :</ span > < span class ="n "> Union</ span > < span class ="p "> [</ span > < span class ="nb "> int</ span > < span class ="p "> ,</ span > < span class ="nb "> str</ span > < span class ="p "> ,</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> device</ span > < span class ="p "> ]</ span > < span class ="o "> =</ span > < span class ="s1 "> 'cuda'</ span > < span class ="p "> )</ span > < span class ="o "> -></ span > < span class ="kc "> None</ span > < span class ="p "> :</ span >
1382- < span class ="sa "> r</ span > < span class ="sd "> """Sets the random number generator state offset of the specified GPU.</ span >
1383-
1384- < span class ="sd "> Args:</ span >
1385- < span class ="sd "> offset (int): The desired offset</ span >
1386- < span class ="sd "> device (torch.device or int, optional): The device to set the RNG state.</ span >
1387- < span class ="sd "> Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).</ span >
1388- < span class ="sd "> """</ span >
1389- < span class ="n "> final_device</ span > < span class ="o "> =</ span > < span class ="n "> _get_device</ span > < span class ="p "> (</ span > < span class ="n "> device</ span > < span class ="p "> )</ span >
1390-
1391- < span class ="k "> def</ span > < span class ="nf "> cb</ span > < span class ="p "> ():</ span >
1392- < span class ="n "> default_generator</ span > < span class ="o "> =</ span > < span class ="n "> _get_generator</ span > < span class ="p "> (</ span > < span class ="n "> final_device</ span > < span class ="p "> )</ span >
1393- < span class ="n "> default_generator</ span > < span class ="o "> .</ span > < span class ="n "> set_offset</ span > < span class ="p "> (</ span > < span class ="n "> offset</ span > < span class ="p "> )</ span >
1394-
1395- < span class ="n "> _lazy_call</ span > < span class ="p "> (</ span > < span class ="n "> cb</ span > < span class ="p "> )</ span >
1396-
1397- < span class ="k "> def</ span > < span class ="nf "> _get_rng_state_offset</ span > < span class ="p "> (</ span > < span class ="n "> device</ span > < span class ="p "> :</ span > < span class ="n "> Union</ span > < span class ="p "> [</ span > < span class ="nb "> int</ span > < span class ="p "> ,</ span > < span class ="nb "> str</ span > < span class ="p "> ,</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> device</ span > < span class ="p "> ]</ span > < span class ="o "> =</ span > < span class ="s1 "> 'cuda'</ span > < span class ="p "> )</ span > < span class ="o "> -></ span > < span class ="nb "> int</ span > < span class ="p "> :</ span >
1398- < span class ="sa "> r</ span > < span class ="sd "> """Returns the random number generator state offset of the specified GPU.</ span >
1399-
1400- < span class ="sd "> Args:</ span >
1401- < span class ="sd "> device (torch.device or int, optional): The device to return the RNG state offset of.</ span >
1402- < span class ="sd "> Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).</ span >
1403-
1404- < span class ="sd "> .. warning::</ span >
1405- < span class ="sd "> This function eagerly initializes CUDA.</ span >
1406- < span class ="sd "> """</ span >
1407- < span class ="n "> _lazy_init</ span > < span class ="p "> ()</ span >
1408- < span class ="n "> final_device</ span > < span class ="o "> =</ span > < span class ="n "> _get_device</ span > < span class ="p "> (</ span > < span class ="n "> device</ span > < span class ="p "> )</ span >
1409- < span class ="n "> default_generator</ span > < span class ="o "> =</ span > < span class ="n "> _get_generator</ span > < span class ="p "> (</ span > < span class ="n "> final_device</ span > < span class ="p "> )</ span >
1410- < span class ="k "> return</ span > < span class ="n "> default_generator</ span > < span class ="o "> .</ span > < span class ="n "> get_offset</ span > < span class ="p "> ()</ span >
1411-
1412-
14131355< span class ="kn "> from</ span > < span class ="nn "> .memory</ span > < span class ="kn "> import</ span > < span class ="o "> *</ span > < span class ="c1 "> # noqa: F403</ span >
14141356
14151357
0 commit comments